1일 1PS 83일차!

📚 문제


[골드 2] - 기업투자

💡 풀이 과정


  • 2차원 DP 문제이다.
투자 액수 기업 A 기업 B 기업 C
1 5 1 4
2 6 5 5
3 7 9 6
4 8 15 7
  • 위 투자금 표를 2차원 배열인 Table 이라 부르겠다.
  • 1만원을 투자할 때 A, B, C 회사 순서대로 최대 이익금을 계산한다.
    • A 회사 (dp[1][1]) 는 Table[1][1]
    • B 회사 (dp[1][2]) 는 max(dp[1][1], Table[1][2])
    • C 회사 (dp[1][3]) 는 max(dp[1][2], Table[1][3])
    • 회사 하나씩 추가해가며 C 회사 이후에는 A, B, C 모두 종합해 최대 이익금이 계산 가능하다.
  • 2만원 투자 시
    • A 회사 (dp[2][1]) 는 Table[2][1]
    • B 회사 (dp[2][2]) 는 max(dp[2][1], Table[2][2])
      • 추가로 B 회사에 1만원 투자 시 이익 + 기존 기업 (A 회사) 에 1만원 투자할 때 최대 이익과 비교가 필요하다.
      • Table[2][1] + dp[1][1] 과 비교
    • C 회사 (dp[2][3]) 는 max(dp[2][2], Table[2][3])
      • 동일하게 C 회사에 1만원 투자 시 이익 + 기존 기업 (A, B 회사) 에 1만원 투자 시 최대 이익과 비교
      • Table[3][1] + dp[2][1] 과 비교
  • 3만원 투자 시
    • 대부분의 내용이 동일하지만 C 회사만 다시 보자
    • 이번에는 (C 회사에 2만원 투자 이익, 기존 회사의 1만원 최대 이익), (C 회사에 1만원 투자 이익, 기존 회사의 2만원 최대 이익) 두 가지를 비교해야한다.
    • max(dp[3][3], Table[3][2] + dp[2][1], Table[3][1] + dp[2][2])
  • 어떤 기업에 투자했는 지 정보도 계속해서 기록해야한다.
  • dp 와 동일하지만 투자한 금액을 기업별로 배열에 저장해야한다.
  • 최대 이익을 찾으면 기존 회사 투자를 copy 하고 현재 회사의 투자금을 업데이트하면 된다.

📌포인트


  • DP
    • 첫 행부터 연속적인 처리를 위해 0으로 채워진 행과 열을 가진다.

💻 코드



import sys

N, M = map(int, sys.stdin.readline().split())
invest_table = [list(map(int, sys.stdin.readline().split())) for _ in range(N)]
temp = [[0] * (M + 1)]
invest_table.insert(0, temp)

# dp 연속 처리를 위해 0으로 초기화된 첫 행과 열을 추가

dp = [[0 for _ in range(M + 1)] for _ in range(N + 1)]

# 어떤 기업에 얼마씩 투자했는 지 기록
dp_invest = [[[0 for _ in range(M + 1)] for _ in range(M + 1)] for _ in range(N + 1)]

for i in range(1, N + 1):
    for j in range(1, M + 1):
        # invest Table 로 초기화. 해당 회사에 i 만원 전부 투자
        dp[i][j] = invest_table[i][j]
        dp_invest[i][j][j] = i

        # j - 1 과 비교해 더 크다면 j - 1 의 이익을 그대로 사용
        if dp[i][j] < dp[i][j-1]:
            dp[i][j] = dp[i][j-1]
            dp_invest[i][j] = dp_invest[i][j - 1].copy()

        # 현재 회사에 i - 1, i - 2, ... , 0 만원 투자했을 때 이익을 비교
        # 이때 dp[k][j - 1] 이 현재 회사에서 k 만원으로 가장 큰 이익을 얻을 수 있는 경우
        for k in range(i):
            if dp[i][j] < invest_table[i - k][j] + dp[k][j - 1]:
                dp[i][j] = invest_table[i - k][j] + dp[k][j - 1]
                dp_invest[i][j] = dp_invest[k][j - 1].copy()
                dp_invest[i][j][j] = i - k


print(dp[-1][-1])
print(*dp_invest[-1][-1][1:], sep=" ")

업데이트:

댓글남기기