[백준 / BOJ] 2042번 : 구간 합 구하기 - JAVA
https://www.acmicpc.net/problem/2042
2042번: 구간 합 구하기
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄
www.acmicpc.net
첫번째 시도 : 누적 합 + HashMap
제목을 보고 가장 먼저 떠올린 것이 누적 합이었다. 수열을 받으면서 동시에 이전 누적값을 계속 더해준다면 a부터 b 사이 구간합을 구할 때 sum[b] - sum[a - 1] 로 바로 구할 수 있기 때문이다.
하지만 이 문제처럼 중간에 값이 바뀔 때마다 누적 합을 다시 구한다면 최대 N * K (배열의 최대 크기 * 구간 합을 출력하는 횟수)번의 연산이 필요하다. 이는 1,000,000 * 10,000으로 2의 10승인데, 보통 2의 9승번의 연산을 1초에 할 수 있다고 가정하니 조건인 2초 내에 불가능하다.
이에 생각을 조금 바꿔서 매번 누적 합을 새로 구하는 게 아니라 변경 사항만 저장해두고, 제시된 구간 안에 변경 사항이 있는지를 체크하도록 했다. 최대 변경 횟수 M * 최대 구간 합 횟수 K는 100,000,000으로, 1초 안에 연산이 가능하다는 판단이 들었기 때문이다.
같은 위치의 숫자를 여러 번 바꿀 수도 있기 때문에 HashMap을 사용해서 이미 나왔던 위치인지 확인했다. 그리고 구간 합을 출력해야 할 때마다 저장해둔 변동 사항을 모두 체크하며 구간 내에 있었던 변화를 반영하도록 했다.
결국 맞았습니다 를 받긴 했지만, 수행 시간이 3000ms나 되는 아주 느린 코드가 되고 말았다.
누적합 : 전체 코드 (3000ms)
public class Main {
public static void main(String[] args) throws NumberFormatException, IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int M = Integer.parseInt(st.nextToken());
int K = Integer.parseInt(st.nextToken());
long[] num = new long[N + 1];
long[] sum = new long[N + 1];
// sum은 누적합 배열이다.
for (int i = 1; i <= N; i++) {
num[i] = Long.parseLong(br.readLine());
sum[i] = num[i] + sum[i - 1];
}
long[][] diff = new long[M][2];
int o = 0;
StringBuilder sb = new StringBuilder();
Map<Integer, Integer> map = new HashMap<>(); // <변동 위치, diff에 변화를 저장한 idx>
for (int i = 0; i < M + K; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
if(a == 1) {
long c = Long.parseLong(st.nextToken());
if(map.containsKey(b)) { // 변동된 적 있는 위치면
int loc = map.get(b); // 그 내용을 diff의 몇 번 idx에 저장했는지 찾고
diff[loc][1] = c - num[b]; // 해당 idx를 다시 수정해준다
} else { // 변동된 적 없는 위치면
diff[o][0] = b; // 위치
diff[o][1] = c - num[b]; // 변화값
map.put(b, o); // map에 변동 위치와 diff의 몇 번 idx에 변화값이 있는지 저장
o++;
}
} else {
int c = Integer.parseInt(st.nextToken());
long total = sum[c] - sum[b - 1]; // 일단 기존 누적합을 구하고
for (int j = 0; j < o; j++) { // 지금까지 있던 모든 변동 사항을 돌면서
if(diff[j][0] >= b && diff[j][0] <= c) { // 구간 내에 일어난 일이면
total += diff[j][1]; // 반영해준다
}
}
sb.append(total).append("\n");
}
}
System.out.println(sb);
}
}
두 번째 시도 : 세그먼트 트리
3000ms에 큰 충격을 받고 원래는 어떻게 푸는 문제인지 확인했다. 알고리즘 분류를 보면 "세그먼트 트리"라고 적혀 있는데, 처음 들어보는 개념이라서 구글링을 통해 완성할 수 있었다.
세그먼트 트리는 주어진 수열을 이진 트리의 leaf로 두고, 부모 노드는 자식 노드 두 개의 합을 갖도록 하는 트리이다. 즉 예제를 바탕으로 그려본다면 다음과 같다.
주황색으로 표시한 게 리프 노드이다. 주어진 예시에는 수열의 원소가 5개 뿐이기 때문에 뒤의 세 칸은 비어 있다.
이 문제를 풀기 위해 세그먼드 트리에 세 가지 함수를 정의했다.
1. 트리 만들기
public long init(long[] num, int node, int start, int end) {
if(start == end) {
return tree[node] = num[start];
}
return tree[node] = init(num, node * 2, start, (start + end) / 2)
+ init(num, node * 2 + 1, (start + end) / 2 + 1, end);
}
tree의 node번째에는 start부터 end까지의 구간 합이 들어가 있다. 따라서, 만일 leaf 노드가 아니라면 바닥까지 찍고 올라오면서 리턴값을 더해주어야 한다.
이진 트리기 때문에 start와 end 구간을 절반씩 쪼개 가지게 되어, start ~ (start + end) / 2
와 (start + end) / 2 + 1 ~ end)
로 나뉜다. 나는 실수로 처음에 start ~ (start + end) / 2와 (start + end) / 2 ~ end로 나눴는데, 이러면 스택 오버플로우가 일어나는 걸 볼 수 있다.
2. 특정 위치의 값 바꾸기
public void update(int node, int start, int end, int idx, long diff) {
if(idx < start || end < idx) return;
tree[node] += diff;
if(start == end) return;
update(node * 2, start, (start + end) / 2, idx, diff);
update(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);
}
여기서도 tree의 node번째에는 start와 end 사이 구간 합이 들어가 있다. 따라서 값이 변할 위치인 idx가 start와 end 사이에 존재한다면, 구간 합에 당연히 diff 만큼의 변화가 생긴다. 이것 역시 start == end가 될 때까지, 즉, 리프 노드에 도착할 때까지 양쪽 자식에게 계속 전달한다.
단 leaf에 도달하지 않더라도, 만일 idx가 start보다 작거나, end보다 크다면(start ~ end 사이가 아니라면) 해당 노드의 구간 합에 아무런 영향도 없으므로 바로 리턴하면 된다.
3. 구간 합 구하기
public long sum(int node, int start, int end, int left, int right) {
if(start > right || end < left) return 0l;
if(left <= start && end <= right) {
return tree[node];
}
return sum(node * 2, start, (start + end) / 2, left, right)
+ sum(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
}
구간 합 구하기는 1번과 2번을 짬뽕시킨 구조이다. idx를 찾는 대신 left ~ right라는 범위를 찾는다. 따라서 주어진 구간의 시작인 left보다 end가 작거나, 주어진 구간의 끝인 right보다 start가 크다면 0을 리턴한다. 더해줄 값이 없기 때문이다.
반대로, 만일 left와 right 사이에 start ~ end가 쏙 들어가 있다면 그건 start ~ end 구간이 원하는 범위 안에 포함된다는 뜻이므로 더 내려갈 것도 없이 통째로 더해준다.
둘 다 아니라면 어중간하게 걸쳐 있는 것이므로, 다시 양쪽 자식 노드로 찢어져 해당하는 구간만 찾도록 한다.
이처럼 세그먼트 트리를 사용하여 실행했을 때에는 568ms로 맞았습니다를 받을 수 있었다. 이전 3000ms에 비하면 5분의 1로 줄어든 것이다.
세그먼트 트리 전체 코드(568ms)
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class Main {
static class SegTree {
long[] tree;
public SegTree(int size) {
this.tree = new long[size * 4];
}
public long init(long[] num, int node, int start, int end) {
if(start == end) {
return tree[node] = num[start];
}
return tree[node] = init(num, node * 2, start, (start + end) / 2)
+ init(num, node * 2 + 1, (start + end) / 2 + 1, end);
}
public void update(int node, int start, int end, int idx, long diff) {
if(idx < start || end < idx) return;
tree[node] += diff;
if(start == end) return;
update(node * 2, start, (start + end) / 2, idx, diff);
update(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);
}
public long sum(int node, int start, int end, int left, int right) {
if(start > right || end < left) return 0l;
if(left <= start && end <= right) {
return tree[node];
}
return sum(node * 2, start, (start + end) / 2, left, right)
+ sum(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
}
}
public static void main(String[] args) throws NumberFormatException, IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int M = Integer.parseInt(st.nextToken());
int K = Integer.parseInt(st.nextToken());
long[] num = new long[N + 1];
for (int i = 1; i <= N; i++) {
num[i] = Long.parseLong(br.readLine());
}
SegTree tree = new SegTree(N);
tree.init(num, 1, 1, N);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < M + K; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
if(a == 1) {
long c = Long.parseLong(st.nextToken());
tree.update(1, 1, N, b, c - num[b]);
num[b] = c;
} else {
int c = Integer.parseInt(st.nextToken());
sb.append(tree.sum(1, 1, N, b, c)).append("\n");
}
}
System.out.println(sb);
}
}