Algorithm

Segment Tree

부분합(Fenwick Tree)

시간복잡도

부분합 찾기, 값 변경 모두 O(logN)
notion image
[8, 15] 구간의 구간 합은 사실 부분 합만을 구한다면 필요가 없다. psum[15]를 구한다면 어차피 루트에 있는 값을 사용하면 되고, 다른 위치의 부분 합을 구할 때는 이 값을 쓸 수가 없기 때문에 굳이 저장해 둘 필요가 없다.

필요없는 구간 지우기

notion image
같은 원리로, 하나의 긴 구간 밑에 두 개의 작은 구간이 있을 때 이 두 구간 중 오른쪽 구간은 항상 지워도 된다.
  • 남은 구간의 갯수는 정확히 n개가 된다. (8 + 4 + 2 + 1 = 15)
  • 오른쪽 끝 원소들을 보면 모두 값이 다른 것을 알 수 있다. 그래서 이 대응을 이용해 1차원 배열에 각 구간의 합을 저장할 수 있다.
  • tree[i] = 오른쪽 끝 위치가 arr[i]인 구간의 합

구간 합 구하기

💡
arr[pos] 까지의 구간 합 psum[pos]를 구하고 싶으면 위 그림에서 pos에서 끝나는 구간의 합 tree[pos]를 답에 더한다. 그리고 남은 부분들을 왼쪽에서 찾아 더하면 된다.
notion image
예를 들어, psum[12] = tree[12] + tree[11] + tree[7]이다.
이제 더해야 할 구간을 어떻게 찾아야 하는지 알아내야 한다.

이진수로 표현하기

펜웍 트리는 각 숫자의 이진수 표현을 이용해 이 문제를 해결할 수 있다. 우선 이를 위해 배열 arr[]와 tree[]의 첫 원소의 인덱스를 1로 바꾸자. 모든 원소의 인덱스에 1을 더해주면 된다. 그러고 나면 특정 부분 합을 구하기 위해 더해야 할 구간 합들을 쉽게 찾을 수 있다.
💡
다음 그림을 보면 각 구간들의 길이는 오른쪽 끝에 있는 0의 개수가 하나 늘 때마다 두 배로 늘어나는 것을 확인할 수 있다.
  • 8의 이진수 표현은 1000(2)이고, 이 수의 오른쪽 끝에는 0이 세 개이므로 8에서 끝나는 구간의 길이는 2^3 = 8이다.
  • 10의 이진수 표현은 1010(2) 이고, 이 수의 오른쪽에는 0이 하나 있으므로 10에서 끝나는 구간의 길이는 2^1 = 2가 된다.
notion image

이진수 표현으로 부분 합 구간 찾기

💡
오른쪽 끝 위치의 이진수 표현에서 마지막 비트를 지우면 다음 구간을 쉽게 찾을 수 있다.
  • 예를 들어 psum[7]을 구하기 위해 더해야 하는 숫자는 7에서 끝나는 구간의 합 tree[7], 6에서 끝나는 구간의 합 tree[6], 그리고 4에서 끝나는 구간 합 tree[4]이다.
  • 이진수로 표현하면, 111(2) → 110(2) → 100(2)이 된다.
notion image

이진수 표현으로 배열 값 변경하기

💡
해당 위치의 값에 숫자를 더하고 빼는 것으로 구현 맨 오른쪽에 있는 1인 비트스스로에게 더해주는 연산을 반복하여 해당 위치를 포함하는 구간들을 모두 만날 수 있다.
  • 예를 들어 arr[5]를 3늘리고 싶다고 하면, arr[5]를 포함하는 모든 구간의 합3씩 늘려주면 된다.
  • 이때 늘려줘야 할 값들은 tree[5], tree[6], tree[8], tree[16]으로, 101(2), 110(2), 1000(2), 10000(2)이다.
  • 101(2) → 110(2) → 1000(2) → 10000(2) 순으로 이동한다.
notion image

구현

배열 값 변경

💡
맨 오른쪽에 있는 1인 비트를 스스로에게 더해주는 연산을 반복하여 해당 위치를 포함하는 구간들을 모두 만날 수 있다.
늘려줘야 할 값들은 3번 노드, 4번 노드로 이진수로 표현하면 11(2) → 100(2)로 이동한다.
pos += (pos & -pos);
notion image

구간 합 구하기

💡
psum은 이진수 표현에서 마지막 비트를 지우면서 다음 구간을 찾아가서 더해주면 된다.  오른쪽 끝 위치의 이진수 표현에서 마지막 비트를 지우면 다음 구간을 쉽게 찾을 수 있다.
예를 들어 psum[3]을 구하기 위해 더해야 하는 숫자는 3에서 끝나는 구간의 합 3번 노드, 2에서 끝나는 구간의 합 tree[2]이다.
  • 3번 노드, 2번 노드를 이진수로 표현하면 11(2) → 10(2)이다.
pos &= (pos - 1);
notion image

코드

public class FenwickTree { static int[] tree; public FenwickTree(int size) { tree = new int[size + 1]; } long sum(int pos){ long result = 0; while(pos > 0){ result += tree[pos]; pos &= (pos - 1); } return result; } void add(int pos, int val){ while(pos < tree.length){ tree[pos] += val; pos += (pos & -pos); } } }

구간 변경

(i, j)에 k 더하기
arr[i] + arr[i+1] + ··· + arr[j]에 k를 더해줘야 한다.
[i+1,  j]까지는 변화가 없고, 각 i와 j+1에 해당하는 구간에만 +k, -k를 더해주면 된다.
예를 들어, 3~4까지 +6을 더해주면 다음과 같이 값이 변경된다.
  • 3이 포함되는 노드들에 +6을 더해주고, 5가 포함되는 노드들에 -6을 더해주면 된다.
notion image

문제

구간 합 구하기

package org.example; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.util.*; public class Main { private static StringBuilder sb = new StringBuilder(); private static int N, M, K; private static long[] list; private static FenWick fenWick; static class FenWick { private long[] data; public FenWick(int size) { this.data = new long[size]; } private void update(int idx, long num) { while (idx < this.data.length) { this.data[idx] += num; idx += (idx & -idx); } } private long sum(int idx) { long allSum = 0; while (idx > 0) { allSum += this.data[idx]; idx -= (idx & -idx); } return allSum; } } public static void main(String[] args) throws IOException { System.setIn(new FileInputStream("input.txt")); BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); StringTokenizer st = new StringTokenizer(br.readLine()); // 입력 N = Integer.parseInt(st.nextToken()); M = Integer.parseInt(st.nextToken()); K = Integer.parseInt(st.nextToken()); // 초기화 list = new long[N + 1]; fenWick = new FenWick(N + 1); // N개의 수가 주어짐 for (int i = 1; i <= N; i++) { long input = Long.parseLong(br.readLine()); fenWick.update(i, input); list[i] = input; } for (int i = 0; i < M + K; i++) { st = new StringTokenizer(br.readLine()); // a가 1인 경우 b번째 수를 c로 바꾸고 if (Integer.parseInt(st.nextToken()) == 1) { int idx = Integer.parseInt(st.nextToken()); long input = Long.parseLong(st.nextToken()); fenWick.update(idx, input - list[idx]); list[idx] = input; } // a가 2인 경우에는 b번째 수부터 c번째 수까지의 합을 구하여 출력 else { int left = Integer.parseInt(st.nextToken()); int right = Integer.parseInt(st.nextToken()); sb.append((fenWick.sum(right) - fenWick.sum(left - 1)) + "\n"); } } System.out.println(sb.toString()); } }

커피숍2

package org.example.tree.fenwicktree; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.util.*; public class FenwickTree { private static StringBuilder sb = new StringBuilder(); private static int N, Q; private static long[] arr, temp; private static FenWick fenWick; static class FenWick { private long[] tree; public FenWick(int size) { this.tree = new long[size]; } private void update(int idx, long num) { while (idx < this.tree.length) { this.tree[idx] += num; idx += (idx & -idx); } } private long sum(int idx) { long allSum = 0; while (idx > 0) { allSum += this.tree[idx]; idx -= (idx & -idx); } return allSum; } } public static void main(String[] args) throws IOException { System.setIn(new FileInputStream("input.txt")); BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); StringTokenizer st = new StringTokenizer(br.readLine()); // 입력 N = Integer.parseInt(st.nextToken()); Q = Integer.parseInt(st.nextToken()); // 초기화 arr = new long[N + 1]; fenWick = new FenWick(N + 1); // N개의 수가 주어짐 st = new StringTokenizer(br.readLine()); for (int n = 1; n <= N; n++) { long input = Long.parseLong(st.nextToken()); arr[n] = input; fenWick.update(n, input); } for (int q = 0; q < Q; q++) { st = new StringTokenizer(br.readLine()); temp = new long[4]; for (int i = 0; i < 4; i++) { temp[i] = Long.parseLong(st.nextToken()); } // swap if (temp[0] > temp[1]) { long t = temp[0]; temp[0] = temp[1]; temp[1] = t; } sb.append(fenWick.sum((int) temp[1]) - fenWick.sum((int) (temp[0] - 1))).append("\n"); long diff = temp[3] - arr[(int) temp[2]]; fenWick.update((int) temp[2], diff); arr[(int) temp[2]] = temp[3]; } System.out.println(sb); } }

출처