1일 1PS 42일차!

📚 문제


골드 2

  • 구현 문제

💡 풀이 과정


  1. 원판의 수를 2차원 배열에 저장하고, 각 판마다 회전한 횟수를 저장할 배열을 초기화한다.
  2. x 배수의 판에 회전하는 숫자를 move_index 에 저장한다.
  3. 가장 안쪽 원판부터 양 옆을 확인하고 마지막 원판이 아니라면 자신의 바깥쪽 원판에 동일한 위치의 숫자를 비교한다.
  4. 같은 숫자가 있으면 target 에 [i, j] 를 저장한다.
  5. target 이 존재하면 해당 위치의 숫자들을 모두 0으로 초기화한다.
  6. target 이 없다면 현재 남아있는 숫자들의 평균을 구해 +1 혹은 -1 을 처리한다.
  7. 남은 원판의 총 합을 구한다.

📌포인트


  • move_index
    • 회전 수를 move_index 에 따로 기록해서 j + move_index[i] 를 통해 회전한 번째 수에 접근했다.
    • 이 과정을 deque 의 rotate() 함수를 사용해서 실제 배열을 쉽게 회전시킬 수 있다.
    • 이 수는 회전 되었다라는 의미이므로 같은 위치에 있는 수를 비교하기 위해서는 ((j - move_index[i]) % M) 이처럼 뺄셈을 해야한다.
  • target
    • target 에 대해서 확실한 초기화가 필요하다.
  • total
    • total -= len(target) 으로 남아있는 수의 개수를 세었는데, 첫 번째 원판이 [1, 1, 1] 일 경우 target 은 [(0, 0), (0, -1), (0, 1), (0, 2)] 로 -1 번째 인덱스가 중첩될 수 있다.
    • 따라서 target 저장시에도 (i, (j - 1) % M ) 으로 저장했다.
  • Zero Division
    • 원판의 모든 숫자가 지워진 상태에서 실행 횟수가 남았다면 total = 0 이므로 에러가 발생할 수 있다.
  • plate[tx][ty] = 0
    • 처음에는 원판의 수를 지움 = 해당 숫자를 0 으로 설정에 위화감을 느꼈다.
    • 처음 에러가 발생하자 이 부분이 문제인 줄 알고 bool 2차원 배열을 생성했다.
    • 하지만 조건에서 숫자들은 1 이상이며, 숫자들의 평균은 1보다 작아질 수 없기 때문에 지워지는 0 과 실제 숫자 0 이 겹칠 일이 없었다.

💻 코드



N, M, T = map(int, input().split())
plate = [list(map(int, input().split())) for _ in range(N)]
move_index = [0] * N
total = N * M
target = set()

for _ in range(T):
    x, d, k = map(int, input().split())
    if d == 0:
        for i in range(x, N + 1, x):
            move_index[i - 1] = (move_index[i - 1] + k) % M
    else:
        for i in range(x, N + 1, x):
            move_index[i - 1] = (move_index[i - 1] - k) % M

    target.clear()
    for i in range(N):
        for j in range(M):
            if plate[i][j - 1] == plate[i][j] != 0:
                target.add((i, (j-1) % M))
                target.add((i, j))
            if i < N - 1 and plate[i][(j - move_index[i]) % M] == plate[i + 1][(j - move_index[i + 1]) % M] != 0:
                target.add((i, (j - move_index[i]) % M))
                target.add((i + 1, (j - move_index[i + 1]) % M))

    if target:
        total -= len(target)
        for tx, ty in target:
            plate[tx][ty] = 0
    else:
        if total == 0:
            break
        avg = sum(sum(subplate) for subplate in plate) / total
        for i in range(N):
            for j in range(M):
                if plate[i][j] != 0 and plate[i][j] > avg:
                    plate[i][j] -= 1
                elif 0 < plate[i][j] < avg:
                    plate[i][j] += 1

print(sum(sum(subplate) for subplate in plate))



업데이트:

댓글남기기