문제 링크: https://www.acmicpc.net/problem/1026

 

1026번: 보물

첫째 줄에 N이 주어진다. 둘째 줄에는 A에 있는 N개의 수가 순서대로 주어지고, 셋째 줄에는 B에 있는 수가 순서대로 주어진다. N은 50보다 작거나 같은 자연수이고, A와 B의 각 원소는 100보다 작거

www.acmicpc.net

 

이 문제를 풀면서 멘탈이 많이 갈려나갔다... 처음부터 방향을 잘못잡고 쉬운문제를 세상에서 가장 어렵게 풀려고 했던 것이다.

내가 이 문제를 풀면서 처음 한 생각은 바로 순차탐색-가지치기 이다. 세상에나...
사실 특정 알고리즘을 코드로 구현하는 것조차 익숙하지 않은 나에게 순차탐색과 가지치기를 구현하는 것은 쉬운일이 아니었고, 2시간을 갈아넣어서 풀었더니 시간초과라고 한다. 우선 나의 사고 과정을 열거해보겠다.

  1. B의 순서를 고정해두고 들어갈 수 있는 모든 A를 선택하는 형태를 보니 이거 그래프모양이 나오네?
  2. 그래프를 그려서 순차탐색을 해보려고 했지만 너무 오래걸릴 것 같아서 중간에 가지치기를 해봐야겠다.
  3. 가지를 어떤 기준으로 칠까? 아래로 갈수록 결과값의 예측이 가능하네? 그럼 각 노드에서 나올 수 있는 최소값을 기준으로 가지를 치자
  4. 최솟값을 구하는 방식으론 뭘할까? A나 B 중에서 최솟값을 모두 곱하고 더하면 이거보단 무조건 큰 결과가 나오게 되어있으니까 이걸로 하자

이런 생각으로 코드를 짰다. 심지어 코드도 개판이다. 이 코드는 더보기에 첨부한다.

더보기
import sys

input = sys.stdin.readline


def initial_sort(A, B):
    mydict = dict(zip(sorted(B), sorted(A, reverse=True)))
    return [mydict[b] for b in B]


def graph_down(A, B, depth, max_depth, last_value, best_value):
    if depth == max_depth-1:
        return B[0] * A[0] + last_value

    min_value = get_expectation(A.copy(), B.copy(), last_value)
    depth_target_b = B[0]
    next_depth_b = B[1:].copy()
    for i in range(len(A)):
        if best_value <= min_value[i]:
            continue
        a_copy = A.copy()
        this_value = last_value + a_copy.pop(i) * depth_target_b
        evaluated_value = graph_down(a_copy, next_depth_b, depth + 1, max_depth, this_value, best_value)
        best_value = min(evaluated_value, best_value)
    return best_value


def get_expectation(A, B, last_value):
    min_value = []
    b_target = B.pop(0)
    for _ in range(len(A)):
        a_target = A.pop(0)
        expected_from_min_b = sum(map(lambda x: x * min(B), A))
        expected_from_min_a = sum(map(lambda x: x * min(A), B))
        temp_ex = max(expected_from_min_a, expected_from_min_b)
        min_value.append(a_target * b_target + last_value + temp_ex)
        A.append(a_target)
    return min_value


depth = int(input())
A = list(map(int, input().strip().split()))
B = list(map(int, input().strip().split()))

A = initial_sort(A, B)
print(graph_down(A, B, 0, depth, 0, float('inf')))

 

그리고 시간초가를 확인하고 구글에 정답을 검색해보는데,,, 세상에나 그냥 가장 큰 숫자와 가장 작은 숫자를 곱해서 더하면 되는거였다. 그 사실을 알고 코드를 짜보니 6줄 나왔다...

import sys

input = sys.stdin.readline
depth = int(input())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
print(sum(map(lambda a: a[0] * a[1], zip(sorted(B), sorted(A, reverse=True)))))

 

지금도 잘 이해할 수 없는 부분이 있다.

 

A의 가장 큰 수와 B의 가장 작은 수가 곱해지고, A의 2번쨰 큰 수와 B의 2번째 작은 수가 곱해지고를 반복했을 때 나오는 결과값이 가장 작다고 확신할 수 있는 근거는 무엇인가?

 

사실 직관적으로는 이해할 수 있다. 나도 어렴풋 이러한 원리를 생각해서 순차탐색을 할 때 A배열의 최초정렬을 이와 같은 방식으로 했으니까. 하지만 이러한 방식에 논리적인 근거가 없다고 생각되어서 완전탐색을 했는데... 아직까지 많이 부족한 것이 느껴진다.

+ Recent posts