Hello, Freakin world!

[백준 11438번][Java] LCA 2 - 희소 배열을 이용해서 LCA 구하기 본문

알고리즘/PS

[백준 11438번][Java] LCA 2 - 희소 배열을 이용해서 LCA 구하기

johnna_endure 2020. 10. 12. 20:00

www.acmicpc.net/problem/11438

 

11438번: LCA 2

첫째 줄에 노드의 개수 N이 주어지고, 다음 N-1개 줄에는 트리 상에서 연결된 두 정점이 주어진다. 그 다음 줄에는 가장 가까운 공통 조상을 알고싶은 쌍의 개수 M이 주어지고, 다음 M개 줄에는 정�

www.acmicpc.net

LCA 은 간단하게 트리 구조의 특성을 이용해 단순하게 해결했는데, 이번 문제는 효율성을 고려하지 않으면 풀 수 없었습니다.

 

영어로 LCA로 검색을 해보면 이 문제에 관해 여러 가지 풀이가 존재한다는 걸 알 수 있습니다.

cp-algorithms.com/graph/lca.html

 

Lowest Common Ancestor - O(sqrt(N)) and O(log N) with O(N) preprocessing - Competitive Programming Algorithms

Lowest Common Ancestor - $O(\sqrt{N})$ and $O(\log N)$ with $O(N)$ preprocessing Given a tree $G$. Given queries of the form $(v_1, v_2)$, for each query you need to find the lowest common ancestor (or least common ancestor), i.e. a vertex $v$ that lies on

cp-algorithms.com

 

위 사이트는 LCA의 여러 풀이 방법을 짧게 소개해놓은 글입니다. 관심있으신 분은 읽어보세용.

이 문제는 할 얘기가 길어질 것 같으니 여기서 한 템포 끊고 본론으로 들어가겠습니다.


LCA 풀이의 전체적인 개요는 다음과 같습니다.

 

1. 트리의 오일러 경로(euler tour)를 리스트에 저장합니다.

2. 이 오일러 경로를 이용해 문제를 RMQ(Range Minimum Query) 문제로 변환합니다.

3. RMQ 문제를 풉니다.

 

RMQ를 푸는 대표적인 방식으로는 세그먼트 트리가 있습니다. 

LCA 풀이 방법의 분기는 3번 과정에서 어떤 RMQ 알고리즘을 선택하느냐에 따라 갈라집니다.

대표적인 방법으로 세그먼트 트리,  희소배열(Sparse array), Sqrt-Decomposition(이 부분은 아직 살펴보지 않아서 잘 모르겠네요)이 있습니다.

 

저는 희소배열을 이용해 풀었습니다.

 

자 전체적인 큰 그림은 됐고, 이제 조금 더 자세히 살펴보도록 하겠습니다.

 

1. 트리의 오일러 경로

 

딱히 특별한 건 없습니다. 아래와 같은 트리가 있다고할 때

 

출처 : https://cp-algorithms.com/graph/lca.html

오일러 경로는 다음과 같습니다.

 

1 > 2 > 5 > 2 > 6 > 2 > 1 > 3 > 1 > 4 > 7 > 4 > 1  

 

그냥 DFS로 완전 탐색하면서 노드를 방문할때와 재귀함수를 리턴할 때 해당 노드를 저장해주면 됩니다.

그런데 오일러 경로를 어디다 써먹을까요?

 

2. 오일러 경로를 이용해 최소 공통 조상(LCA) 찾기

 

6,3의 LCA를 찾으려고 합니다. 그리고 다음의 오일러 경로 있을 때, 어떻게 LCA를 찾을 수 있을까요?

 

1 > 2 > 5 > 2 > 6 > 2 > 1 > 3 > 1 > 4 > 7 > 4 > 1  

 

일단 LCA가 6과 3의 오일러 경로 사이에 있다는 건 쉽게 눈치챌 수 있습니다. LCA는 반드시 6에서 3으로 가는 경로에 위치하기 때문입니다. 

1 > 2 > 5 > 2 > 6 > 2 > 1 > 3 > 1 > 4 > 7 > 4 > 1  

그리고 위 그래프는 또 다른 규칙이 있습니다. 위에서 아래로, 왼쪽에서 오른쪽 방향으로 갈수록 정점의 번호가 크게 매겨졌다는 겁니다(이 규칙이 문제에는 적용되지 않습니다). 이 규칙은 아주 중요합니다. 이로 인해 결국 [6~3] 구간의 최소값인 1이 LCA가 됩니다.

LCA는 [서브] 트리의 루트가 되어야 되기 때문에 부모를 따라 올라 가서 최소 번호를 가진 정점을 지난 후 다시 내려 가면서 정점의 번호가 커지기 때문입니다.

만약 어떤 두 노드 중 하나가 LCA가 되는 경우도 결국 두 정점 중 하나가 최소값이 되므로 결국 동일한 논리가 적용됩니다. 

 

앞에서 미리 말하진 않았지만 오일러 경로의 구간을 선택하는 것도 짚고 넘어가겠습니다. 위의 6,3 은 리프 노드라 오일러 경로에서 중복되지 않는 값이라 구간을 쉽게 정할 수 있었습니다. 하지만 2,4번 정점의 LCA를 구할 때는 어떤 구간을 선택해야 할까요?

사실 어떤 구간을 선택해도 상관없습니다. 하지만 구현의 편의를 위해 보통 정점이 처음 나타나는 인덱스를 선택합니다.

 

1 > 2 > 5 > 2 > 6 > 2 > 1 > 3 > 1 > 4 > 7 > 4 > 1  

 

이 구간의 최소값은 1이므로 LCA는 1이 됩니다.

 

이 문제는 정점의 번호만으로 최소값 구해서 문제를 풀 수 없습니다. 꼭 위쪽에 위치한 정점의 번호가 낮을 거라는 보장이 없기 때문입니다. 그래서 정점이 처음 발견된 순서를 따로 저장합니다.  indexToOrder, orderToIndex 라는 두 가지 배열을 만들어서 처음 발견한 순서와 정점 번호 간에 변환이 가능하도록 합시다. 

 

발견 순서로 변환하는 이유는 DFS와 트리의 특성으로 인해 부모 노드의 번호가 자식 노드의 번호보다 작을 것이 보장되기 때문입니다.

그림에서 노드를 따라가보면, 오른쪽의 트리 역시 오일러 경로 구간의 최소값이 LCA가 됨을 알 수 있습니다.

 

 

이전에 구한 오일러 경로의 정점 번호를 다시 각 발견 순서로 변환하고 가장 빨리 발견된 순서을 구합니다. 그리고 다시 그 값을 정점 번호로 변환하면 LCA가 됩니다.

자, 이제 LCA 문제가 RMQ 문제로 변했습니다. 다음은 희소 배열에 대해 설명하겠습니다.

 

3. 희소 배열로 RMQ(구간 최소 쿼리)문제 해결하기

 

영어 원문으로는 Sparse Array 라는데, 도통 의미가 와닿지 않은 단어입니다...

 

일단 이 방법은 구간의 최소/최대값을 저장하는 일종의 자료 구조로 동적 계획법을 이용해 구현합니다.

 

먼저 용어부터 정의하겠습니다.

- sparseArray[i][j] : j번째 인덱스에서 시작하는  2^i 길이의 구간의 최소값. 

 

말로 하면 장황하니 그림으로 먼저 살펴보겠습니다.

 

 

sparseArray[0][j] 에서는 구간의 길이가 2^0이므로 1입니다. 그렇기 때문에 원본 배열과 동일한 모습입니다.

sparseArray[1][j] 은 구간의 길이가 2^1로 2입니다. 그래서 sparse[0][j]와 sparse[0][j+1]을 비교해 최소값을 가져옵니다.

sparseArray[2][j] 은 구간의 길이가 4입니다. 그래서 sparseArray[2][j] 와 sparseArray[2][j + 2^1] 중 최소값을 가져옵니다.

 

패턴을 보시면 감이 잡히시나요?

세그먼트 트리를 배우신 분이라면 이게 바텁업 버전의 세그먼트 트리와 같다는 사실을 눈치채실지도 모르겠네요.

j행의 값을 채워나갈 때, 구간을 이등분해 해당 구간에 필요한 값을 j-1행에서 찾고 그 값들을 비교해 최소값을 채운다는 것을 눈여겨보세요.

 

점화식은 다음과 같습니다.

sparseArray[i][j] = if(j + pow(2,i-1) < n) min(sparseArray[i-1][j], sparseArray[i-1][j + pow(2,i-1)]

n : 주어진 배열의 총 길이
pow(a,b) : a^b 

 

위의 점화식으로 희소배열을 초기화한다고 끝이 아닙니다. 세그먼트 트리와 동일하게 쿼리하는 과정이 필요합니다.

이 자료구조는 2의 배수의 사이즈의 구간을 가정하고 값들을 저장했습니다.  쿼리 구간이 홀수라면 어떻게 하나요?

 

이 점은 비트 연산을 이용하면 간편하게 구현 가능합니다.

예를 들어 쿼리 구간의 길이가 7이라면 4,2,1 이렇게 세 구간으로 나눠서 그 중에 최소값을 구해야 합니다.

7을 비트로 나타내면 111 이므로 이는 정확하게 구간이 4,2,1로 나눠질 수 있다는 걸 말해줍니다.

자세한 구현은 코드를 참고하세요.

 

import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringTokenizer;

/*
LCA2
https://www.acmicpc.net/problem/11438
 */
public class Main {
	static int n,k;
	static Node[] nodes;
	static int[][] sparseArr;
	static int[] first, indexToOrder, orderToIndex;
	static List<Integer> eulerPath = new ArrayList<>();
	public static void main(String[] args) throws IOException {
//		InputReader reader = new InputReader();
		InputReader reader = new InputReader("testcase.txt");
		n = reader.readInt();
		nodes = new Node[n];
		first = new int[n];
		indexToOrder = new int[n]; orderToIndex = new int[n];
		for (int i = 0; i < n; i++) { nodes[i] = new Node(i); }
		for (int i = 0; i < n-1; i++) {
			StringTokenizer st = new StringTokenizer(reader.readLine());
			int u = Integer.parseInt(st.nextToken())-1;
			int v = Integer.parseInt(st.nextToken())-1;
			nodes[u].children.add(v);
			nodes[v].children.add(u);
		}

		boolean[] visited = new boolean[n];
		eulerTravel(0, visited);

		k = log2(eulerPath.size());
		sparseArr = new int[k+1][eulerPath.size()];

		setSparseArray();
//		for (int i = 0; i <= k; i++) {
//			System.out.println(Arrays.toString(sparseArr[i]));
//		}
//		System.out.println();

		int m = reader.readInt();
		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < m; i++) {
			StringTokenizer st = new StringTokenizer(reader.readLine());
			int u = Integer.parseInt(st.nextToken())-1;
			int v = Integer.parseInt(st.nextToken())-1;
			int firstU = first[u];
			int firstV = first[v];
			if(firstU > firstV) {
				int temp = firstU;
				firstU = firstV;
				firstV = temp;
			}
			sb.append(query(firstU, firstV) + "\n");
		}
		sb.deleteCharAt(sb.length()-1);
		System.out.println(sb.toString());
	}

	static int order = 0;
	private static void eulerTravel(int index, boolean[] visited) {
		visited[index] = true;

		Node node = nodes[index];
		indexToOrder[index] = order;
		orderToIndex[order] = index;
		eulerPath.add(index);
		first[index] = eulerPath.size()-1;
		order++;
		for (int i = 0; i < node.children.size(); i++) {
			int childId = node.children.get(i);
			if(!visited[childId]) {
				eulerTravel(childId, visited);
				eulerPath.add(index);
			}
		}
	}

	private static void setSparseArray() {
		for (int i = 0; i < eulerPath.size(); i++) {
			sparseArr[0][i] = indexToOrder[eulerPath.get(i)];
		}

		for (int i = 1; i <= k; i++) {
			for (int j = 0; j < eulerPath.size(); j++) {
				if(j+pow(2, i-1) < eulerPath.size()) {
					sparseArr[i][j] = Math.min(sparseArr[i-1][j],
							sparseArr[i-1][j+pow(2, i-1)]);
				}
			}
		}
	}

	private static int query(int start, int end) {
		int length = end-start+1;

		int k = log2(length);
		//2의 제곱수인 경우
//		if(pow(2,k) == length) {
//			return orderToIndex[sparseArr[k][start]]+1;
//		}
		//2의 제곱수가 아닌 경우
		int minOrder = Integer.MAX_VALUE;
		for (int i = 0; i <= k; i++) {
			if((length & (1 << i)) != 0) {
				minOrder = Math.min(sparseArr[i][start], minOrder);
				start += pow(2,i);
			}
		}
		return orderToIndex[minOrder]+1;
	}

	private static int pow(int base, int exp) {
		int ret = 1;
		while(exp != 0) {
			ret *= base;
			exp--;
		}
		return ret;
	}
	/*
	n보다 작거나 같은 2^k의 k를 반환
	 */
	private static int log2(int n) {
		if(n == 0) return 0;
		int cnt = 0;
		while(true) {
			n /= 2;
			cnt++;
			if(n == 0) break;
		}
		return cnt-1;
	}
}
class Node {
	int id;
	List<Integer> children = new ArrayList<>();
	public Node(int id) {
		this.id = id;
	}
}
class InputReader {
	private BufferedReader br;

	public InputReader() {
		br = new BufferedReader(new InputStreamReader(System.in));
	}

	public InputReader(String filepath) {
		try {
			br = new BufferedReader(new FileReader(filepath));
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
	}

	public List<Character> readLineIntoCharList() throws IOException {
		List<Character> l = new ArrayList<>();
		while(true) {
			int readVal = br.read();
			if(readVal == '\n' || readVal == -1) break;
			l.add((char)readVal);
		}
		return l;
	}

	public boolean ready() throws IOException {
		return br.ready();
	}

	public String readLine() throws IOException {
		return br.readLine();
	}
	public int readInt() throws IOException {
		return Integer.parseInt(readLine());
	}
	public Long readLong() throws IOException {
		return Long.parseLong(readLine());
	}
}
Comments