Hello, Freakin world!

구간 트리(Segment Tree) 본문

알고리즘/자료구조

구간 트리(Segment Tree)

johnna_endure 2020. 9. 9. 16:13

구간 트리란?

구간 트리의 정체는 이진 트리입니다. 구간 트리 라는 이름은 그저 특정 용도로 사용되는 이진 트리의 한 형태를 나타낼 뿐입니다.

대게 구간 트리는 일차원 배열 상의 특정 구간에 대한 요청에 빠르게 대답하기 위해 사용합니다.

 

예를 들어, 1 2 3 4 5 라는 숫자 배열이 있다고 하겠습니다. 이 배열의 특정 구간의 합을 구한다고 할 때, 가장 간단한 방법은 그 특정 구간을 순회하면서 값을 더해나가는 겁니다. 이 경우 시간 복잡도는 O(n)이 됩니다.

만약 배열의 길이가 10억쯤 된다면 어떨까요? 컴퓨터가 1초에 1억번 정도의 루프를 돌린다고 하면 10초가 걸리겠네요.

그리고 또 이 합연산이 여러번 호출될 가능성이 있다면 어떨까요? 이 연산은 매번 합을 새로 구하기 때문에 문제가 됩니다.

 

구간 트리는 위와 같은 문제에 해법이 될 수 있습니다. log(n)의 시간 복잡도를 가지는 연산을 통해 한번만 구간 트리를 초기화하면 어떤 질의라도 log(n)의 복잡도로 해결할 수 있습니다.

 

순수하게 구간 트리를 이용해 답을 찾는 단계은 두 가지 과정을 거칩니다. 초기화 과정과 질의 과정입니다.

그리고 이외에 원본 배열이 수정될 경우 구간 트리 log(n)의 복잡도로 수정하는 과정이 존재합니다. 수정해야 될 부분이 많을 경우엔 그냥 구간 트리를 다시 초기화하는게 나을 수도 있지만, 수정되는 부분이 작은 경우에는 도움이 됩니다.  

 

아래부터는 부분합을 구하는 문제에 대해 구간 트리를 구현해보겠습니다. 그럼 우선 초기화 과정부터~

 

 

구간 트리의 초기화

 

 

 

초기화 자체는 간단합니다. 

구간을 반으로 나눠서 왼쪽 구간의 합은 왼쪽 자식 노드에 저장하고 오른쪽 구간의 합은 오른쪽 자식 노드에 저장. 그리고 현재 노드는 왼쪽 자식노드의 부분합 + 오른쪽 자식 노드 부분합 을 저장합니다.

 

	/*
	left : 배열 범위의 왼쪽
	right : 배열 범위의 오른쪽
	node : segmentTree 배열 상의 노드(인덱스)

	배열의 구간 합을 segmentTree에 저장.
	*/
	private static long initSegmentTree(int left, int right, int node) {
		if(left == right) return segmentTree[node] = numbers[left];

		int mid = (left + right)/2;
		long leftSum = initSegmentTree(left, mid, 2*node);
		long rightSum = initSegmentTree(mid+1, right, 2*node+1);

		return segmentTree[node] = leftSum + rightSum;
	}

 

참고) 이진 트리에서는 노드를 일차원 배열로 나타낼 때 왼쪽, 오른쪽 자식 노드를 각각 2*node,  2*node+1로 나타낼 수 있습니다 .

 

 

구간 트리의 질의 처리

 

 

위 그림처럼 배열의 [3~5] 범위의 구간합을 구하려고 할 때, 이미 초기화 해둔 구간 트리를 이용할 수 있습니다.

 

일단 구현코드를 먼저 살펴보겠습니다.

 

	/*
	targetLeft : 부분합을 찾기 위한 원본 배열 범위의 왼쪽 경계
	targetLeft : 부분합을 찾기 위한 원본 배열 범위의 오른쪽 경계
	node : 구간트리 배열의 노드(인덱스)
	nodeLeft : 노드에 해당하는 원본 배열 범위의 왼쪽 경계
	nodeRight : 노드에 해당하는 원본 배열 범위의 오른쪽 경계
	 */
	private static long querySum(int targetLeft, int targetRight,
	                            int node, int nodeLeft, int nodeRight) {
		//두 구간이 겹치지 않으면 무시한다.
		if(targetRight < nodeLeft || targetLeft > nodeRight) return 0;
		//node가 표현하는 범위가 target의 범위에 완전히 포함되는 경우
		if(targetLeft <= nodeLeft  && nodeRight <= targetRight) return segmentTree[node];

		//겹치지만 완전히 포함하지 않는 경우
		int mid = (nodeLeft+nodeRight)/2;
		long leftSum = querySum(targetLeft, targetRight, 2*node, nodeLeft, mid);
		long rightSum = querySum(targetLeft, targetRight, 2*node+1, mid+1, nodeRight);

		return leftSum + rightSum;
	}
    

 

위의 구현에서는 구간 트리를 1차원 배열에 표현했기 때문에 구간 트리 배열의 각 노드에는 부분합만 저장되어 있습니다.

하지만 개념적으로 각 노드는 해당 구간 정보를 가지고 있어야 합니다. 그래서 위 구현에서는 재귀함수 파라미터에 node, nodeLeft, nodeRight를 계산해서 넘기도록 하고 있습니다. 

 

합을 구하는 방식은 원하는 구간(타겟 구간)과 해당 노드의 구간을 비교하는 것에서 시작됩니다.

 

1. 타겟 구간과 노드 구간이 전혀 겹치지 않는 경우

2. 타겟 구간에 노드 구간이 완전히 포함되는 경우

3. 타겟 구간에 노드 구간이 걸치는 경우

 

1의 경우는 합에서 제외할 것이기 때문에 0을 반환합니다.

2의 경우는 구간 노드에 저장된 값을 반환하고, 3의 경우는 다시 구간을 쪼갭니다.

 

 

구간 트리의 수정

배열의 하나의 값을 변경하는 경우, 다음과 같이 구간 트리를 수정할 수 있습니다.

 

	/*
	index : 수정하려고 하는 원본 배열 요소의 인덱스
	newVal : 새로운 값
	node : 구간 트리의 노드 인덱스
	nodeLeft : 구간 트리 노드에 해당하는 원본 배열 범위의 왼쪽 경계
	nodeLeft : 구간 트리 노드에 해당하는 원본 배열 범위의 오른쪽 경계
	
	 */
	private static long update(int index, long newVal,
	                          int node, int nodeLeft, int nodeRight) {
		if(index < nodeLeft || index > nodeRight) return segmentTree[node];

		if(nodeLeft == nodeRight) return segmentTree[node] = newVal;
		int mid = (nodeLeft+nodeRight)/2;
		long leftSum = update(index, newVal, 2*node, nodeLeft, mid);
		long rightSum = update(index, newVal, 2*node+1, mid+1, nodeRight);
		return segmentTree[node] = leftSum + rightSum;
	}

 

방식은 트리의 노드를 방문하면서 수정하려는 index가 포함되는 노드를 찾아서 재귀함수를 따라 리프까지 간 뒤,

값을 수정하고 재귀함수를 리턴하면서 합들을 수정하는 방식입니다.

 

 

 

Comments