JAVA/문제 풀이

[백준 / BOJ] 2042번 : 구간 합 구하기 - JAVA

ahue 2023. 9. 25. 23:02
728x90

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

 

첫번째 시도 : 누적 합 + HashMap

제목을 보고 가장 먼저 떠올린 것이 누적 합이었다. 수열을 받으면서 동시에 이전 누적값을 계속 더해준다면 a부터 b 사이 구간합을 구할 때 sum[b] - sum[a - 1] 로 바로 구할 수 있기 때문이다.

하지만 이 문제처럼 중간에 값이 바뀔 때마다 누적 합을 다시 구한다면 최대 N * K (배열의 최대 크기 * 구간 합을 출력하는 횟수)번의 연산이 필요하다. 이는 1,000,000 * 10,000으로 2의 10승인데, 보통 2의 9승번의 연산을 1초에 할 수 있다고 가정하니 조건인 2초 내에 불가능하다.

이에 생각을 조금 바꿔서 매번 누적 합을 새로 구하는 게 아니라 변경 사항만 저장해두고, 제시된 구간 안에 변경 사항이 있는지를 체크하도록 했다. 최대 변경 횟수 M * 최대 구간 합 횟수 K는 100,000,000으로, 1초 안에 연산이 가능하다는 판단이 들었기 때문이다.

같은 위치의 숫자를 여러 번 바꿀 수도 있기 때문에 HashMap을 사용해서 이미 나왔던 위치인지 확인했다. 그리고 구간 합을 출력해야 할 때마다 저장해둔 변동 사항을 모두 체크하며 구간 내에 있었던 변화를 반영하도록 했다.

결국 맞았습니다 를 받긴 했지만, 수행 시간이 3000ms나 되는 아주 느린 코드가 되고 말았다.

 

누적합 : 전체 코드 (3000ms)

public class Main {

    public static void main(String[] args) throws NumberFormatException, IOException {

        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        int K = Integer.parseInt(st.nextToken());

        long[] num = new long[N + 1];
        long[] sum = new long[N + 1];

        // sum은 누적합 배열이다.
        for (int i = 1; i <= N; i++) {
            num[i] = Long.parseLong(br.readLine());
            sum[i] = num[i] + sum[i - 1];
        }



        long[][] diff = new long[M][2];
        int o = 0;
        StringBuilder sb = new StringBuilder();
        Map<Integer, Integer> map = new HashMap<>(); // <변동 위치, diff에 변화를 저장한 idx>

        for (int i = 0; i < M + K; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());

            if(a == 1) {
                long c = Long.parseLong(st.nextToken());
                if(map.containsKey(b)) { // 변동된 적 있는 위치면
                    int loc = map.get(b); // 그 내용을 diff의 몇 번 idx에 저장했는지 찾고
                    diff[loc][1] = c - num[b]; // 해당 idx를 다시 수정해준다
                } else {    // 변동된 적 없는 위치면
                    diff[o][0] = b; // 위치
                    diff[o][1] = c - num[b]; // 변화값
                    map.put(b, o); // map에 변동 위치와 diff의 몇 번 idx에 변화값이 있는지 저장
                    o++;
                }
            } else {
                int c = Integer.parseInt(st.nextToken());
                long total = sum[c] - sum[b - 1]; // 일단 기존 누적합을 구하고
                for (int j = 0; j < o; j++) { // 지금까지 있던 모든 변동 사항을 돌면서
                    if(diff[j][0] >= b && diff[j][0] <= c) { // 구간 내에 일어난 일이면
                        total += diff[j][1]; // 반영해준다
                    }
                }
                sb.append(total).append("\n");
            }
        }

        System.out.println(sb);
    }

}

 

두 번째 시도 : 세그먼트 트리

3000ms에 큰 충격을 받고 원래는 어떻게 푸는 문제인지 확인했다. 알고리즘 분류를 보면 "세그먼트 트리"라고 적혀 있는데, 처음 들어보는 개념이라서 구글링을 통해 완성할 수 있었다.

세그먼트 트리는 주어진 수열을 이진 트리의 leaf로 두고, 부모 노드는 자식 노드 두 개의 합을 갖도록 하는 트리이다. 즉 예제를 바탕으로 그려본다면 다음과 같다.

주황색으로 표시한 게 리프 노드이다. 주어진 예시에는 수열의 원소가 5개 뿐이기 때문에 뒤의 세 칸은 비어 있다.

이 문제를 풀기 위해 세그먼드 트리에 세 가지 함수를 정의했다.

 

1. 트리 만들기

public long init(long[] num, int node, int start, int end) {

    if(start == end) {
        return tree[node] = num[start];
    }

    return tree[node] = init(num, node * 2, start, (start + end) / 2) 
        + init(num, node * 2 + 1, (start + end) / 2 + 1, end);
}

tree의 node번째에는 start부터 end까지의 구간 합이 들어가 있다. 따라서, 만일 leaf 노드가 아니라면 바닥까지 찍고 올라오면서 리턴값을 더해주어야 한다.

 

이진 트리기 때문에 start와 end 구간을 절반씩 쪼개 가지게 되어, start ~ (start + end) / 2(start + end) / 2 + 1 ~ end)로 나뉜다. 나는 실수로 처음에 start ~ (start + end) / 2와 (start + end) / 2 ~ end로 나눴는데, 이러면 스택 오버플로우가 일어나는 걸 볼 수 있다.

 

2. 특정 위치의 값 바꾸기

public void update(int node, int start, int end, int idx, long diff) {

    if(idx < start || end < idx) return;

    tree[node] += diff;
    if(start == end) return;

    update(node * 2, start, (start + end) / 2, idx, diff);
    update(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);

}

여기서도 tree의 node번째에는 start와 end 사이 구간 합이 들어가 있다. 따라서 값이 변할 위치인 idx가 start와 end 사이에 존재한다면, 구간 합에 당연히 diff 만큼의 변화가 생긴다. 이것 역시 start == end가 될 때까지, 즉, 리프 노드에 도착할 때까지 양쪽 자식에게 계속 전달한다.

 

단 leaf에 도달하지 않더라도, 만일 idx가 start보다 작거나, end보다 크다면(start ~ end 사이가 아니라면) 해당 노드의 구간 합에 아무런 영향도 없으므로 바로 리턴하면 된다.

 

3. 구간 합 구하기

public long sum(int node, int start, int end, int left, int right) {

    if(start > right || end < left) return 0l;

    if(left <= start && end <= right) {
        return tree[node];
    }

    return sum(node * 2, start, (start + end) / 2, left, right) 
        + sum(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
}

구간 합 구하기는 1번과 2번을 짬뽕시킨 구조이다. idx를 찾는 대신 left ~ right라는 범위를 찾는다. 따라서 주어진 구간의 시작인 left보다 end가 작거나, 주어진 구간의 끝인 right보다 start가 크다면 0을 리턴한다. 더해줄 값이 없기 때문이다.

 

반대로, 만일 left와 right 사이에 start ~ end가 쏙 들어가 있다면 그건 start ~ end 구간이 원하는 범위 안에 포함된다는 뜻이므로 더 내려갈 것도 없이 통째로 더해준다.

 

둘 다 아니라면 어중간하게 걸쳐 있는 것이므로, 다시 양쪽 자식 노드로 찢어져 해당하는 구간만 찾도록 한다. 

 

이처럼 세그먼트 트리를 사용하여 실행했을 때에는 568ms로 맞았습니다를 받을 수 있었다. 이전 3000ms에 비하면 5분의 1로 줄어든 것이다.

 

세그먼트 트리 전체 코드(568ms)

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {

    static class SegTree {
        long[] tree;

        public SegTree(int size) {
            this.tree = new long[size * 4];
        }

        public long init(long[] num, int node, int start, int end) {

            if(start == end) {
                return tree[node] = num[start];
            }

            return tree[node] = init(num, node * 2, start, (start + end) / 2) 
                + init(num, node * 2 + 1, (start + end) / 2 + 1, end);
        }

        public void update(int node, int start, int end, int idx, long diff) {

            if(idx < start || end < idx) return;

            tree[node] += diff;
            if(start == end) return;

            update(node * 2, start, (start + end) / 2, idx, diff);
            update(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);
        }

        public long sum(int node, int start, int end, int left, int right) {

            if(start > right || end < left) return 0l;

            if(left <= start && end <= right) {
                return tree[node];
            }

            return sum(node * 2, start, (start + end) / 2, left, right) 
                + sum(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
        }
    }
    public static void main(String[] args) throws NumberFormatException, IOException {

        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        int K = Integer.parseInt(st.nextToken());

        long[] num = new long[N + 1];
        for (int i = 1; i <= N; i++) {
            num[i] = Long.parseLong(br.readLine());
        }

        SegTree tree = new SegTree(N);
        tree.init(num, 1, 1, N);

        StringBuilder sb = new StringBuilder();

        for (int i = 0; i < M + K; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());

            if(a == 1) {
                long c = Long.parseLong(st.nextToken());
                tree.update(1, 1, N, b, c - num[b]);
                num[b] = c;
            } else {
                int c = Integer.parseInt(st.nextToken());
                sb.append(tree.sum(1, 1, N, b, c)).append("\n");
            }
        }
        System.out.println(sb);
    }

}
반응형