技術メモ

神奈川在住のITエンジニアの備忘録。おもにプログラミングやネットワーク技術について、学んだことを自分の中で整理するためにゆるゆると書いています。ちゃんと検証できていない部分もあるのでご参考程度となりますが、誰かのお役に立てれば幸いです。

Union-Find 木

leetCode の以下の問題を Union-Find 木を使って解いた。
https://leetcode.com/problems/smallest-string-with-swaps/

今後の参考のために、自分の書いたコードを載せておく。

import java.util.*;

public class SmallestString {
    public String smallestStringWithSwaps(String s, List<List<Integer>> pairs) {
        int len = s.length();
        UnionFind uf = new UnionFind(len);

        // Grouping する
        for (List<Integer> pair : pairs) {
            int x = pair.get(0);
            int y = pair.get(1);
            if (!uf.isSame(x, y)) {
                uf.unite(x, y);
            }
        }

        // root の index ⇒ members となる、char の Map を作成する。
        Map<Integer, Queue<Character>> rootToMembers = new HashMap<>();
        for (int i = 0; i < len; i++) {
            int root = uf.findRoot(i);
            if (rootToMembers.containsKey(root)) {
                rootToMembers.get(root).add(s.charAt(i));
            }
            else {
                PriorityQueue<Character> q = new PriorityQueue<>();
                q.add(s.charAt(i));
                rootToMembers.put(root, q);
            }
        }

        char[] swapped = new char[len];
        for (int i = 0; i < len; i++) {
            Queue<Character> members = rootToMembers.get(uf.findRoot(i));
            if (!members.isEmpty()) {
                swapped[i] = members.poll();
            }
        }
        return new String(swapped);
    }

    private static class UnionFind {
        int[] roots;
        int[] ranks;

        UnionFind(int size) {
            roots = new int[size];

            // 初期値として i の root には i (自分) を入れる。
            for (int i = 0; i < size; i++) {
                roots[i] = i;
            }

            ranks = new int[size];
        }

        void unite(int x, int y) {
            int rootX = findRoot(x);
            int rootY = findRoot(y);

            if (ranks[x] > ranks[y]) {
                // x を y の所属するグループに入れる。
                roots[rootY] = rootX;
                ranks[rootX]++;
            }
            else {
                // y を x の所属するグループに入れる。
                roots[rootX] = rootY;
                ranks[rootY]++;
            }
        }

        // x の root を返す。(再帰バージョン)
        int findRoot(int x) {
            if (x == roots[x]) {
                return x;
            }
            else {
                roots[x] = findRoot(roots[x]);
                return roots[x];
            }
        }

        // x の root を返す。(繰り返しバージョン)
        int findRootItr(int x) {
            while (x != roots[x]) {
                // root を辿るついでに経路圧縮
                roots[x] = roots[roots[x]];
                x = roots[x];
            }
            return roots[x];
        }

        // root が同じかどうか、つまり、同じグループかどうかを返す。
        boolean isSame(int x, int y) {
            return findRoot(x) == findRoot(y);
        }
    }
}

ポイントは、コードの中のコメントに書いている。
findRoot メソッドは、再帰と繰り返しの2バージョンを書いたみた。この場合、再帰の方が分かりやすいと思う。