Count Unreachable Pairs of Nodes in an Undirected Graph

medium union find components

Problem

You have an undirected graph with n nodes labeled 0 to n − 1 and a list of edges. Count the unordered pairs of distinct nodes that cannot reach each other through any path. Since n can reach 10⁵, the count can be as large as n(n − 1)/2 ≈ 5·10⁹, which exceeds the range of a 32-bit integer, so the result must be stored in a 64-bit type.

Input: n = 7, edges = [[0,2],[0,5],[2,4],[1,6],[5,4]]
Output: 14
The components are {0,2,4,5}, {1,6} and {3}; pairs across them: 4·2 + 4·1 + 2·1 = 14.

def count_pairs(n, edges):
    """Return how many unordered node pairs lie in different connected components.

    Nodes are labelled 0..n-1 and ``edges`` is an iterable of (a, b) pairs.
    Components are merged with a disjoint-set forest (union by size, path
    halving). The answer sums, for each component root in index order, the
    component's size times the number of nodes in components not yet visited.
    """
    root_of = list(range(n))
    weight = [1] * n

    def _root(node):
        # Path halving: re-point each visited node at its grandparent.
        while root_of[node] != node:
            root_of[node] = root_of[root_of[node]]
            node = root_of[node]
        return node

    for u, w in edges:
        big, small = _root(u), _root(w)
        if big == small:
            continue
        if weight[big] < weight[small]:
            big, small = small, big
        # Hang the smaller tree under the larger one so trees stay shallow.
        root_of[small] = big
        weight[big] += weight[small]

    component_sizes = [weight[node] for node in range(n) if _root(node) == node]
    unpaired = n
    total = 0
    for s in component_sizes:
        # Pair every node of this component with every node still unvisited.
        unpaired -= s
        total += s * unpaired
    return total
function countPairs(n, edges) {
  const parent = Array.from({ length: n }, (_, i) => i);
  const size = new Array(n).fill(1);
  const find = (x) => {
    while (parent[x] !== x) {
      parent[x] = parent[parent[x]];
      x = parent[x];
    }
    return x;
  };
  for (const [a, b] of edges) {
    let ra = find(a), rb = find(b);
    if (ra !== rb) {
      if (size[ra] < size[rb]) [ra, rb] = [rb, ra];
      parent[rb] = ra;
      size[ra] += size[rb];
    }
  }
  let ans = 0, remaining = n;
  for (let v = 0; v < n; v++) {
    if (find(v) === v) {
      remaining -= size[v];
      ans += size[v] * remaining;
    }
  }
  return ans;
}
/**
 * Counts unordered pairs of nodes that lie in different connected components
 * of an undirected graph whose nodes are labelled 0..n-1.
 *
 * @param n     number of nodes
 * @param edges undirected edges, each an {@code int[]{a, b}}
 * @return number of mutually unreachable pairs, as a long because it can
 *         exceed the int range
 */
long countPairs(int n, int[][] edges) {
    int[] link = new int[n];
    int[] weight = new int[n];
    for (int node = 0; node < n; node++) {
        link[node] = node;
        weight[node] = 1;
    }

    for (int[] edge : edges) {
        int big = find(link, edge[0]);
        int small = find(link, edge[1]);
        if (big == small) {
            continue;
        }
        if (weight[big] < weight[small]) {
            int tmp = big;
            big = small;
            small = tmp;
        }
        // Union by size: attach the smaller tree beneath the larger root.
        link[small] = big;
        weight[big] += weight[small];
    }

    long unpaired = n;
    long total = 0;
    for (int node = 0; node < n; node++) {
        if (find(link, node) != node) {
            continue;
        }
        unpaired -= weight[node];
        // int * long promotes to long, so the product cannot overflow int.
        total += weight[node] * unpaired;
    }
    return total;
}

/**
 * Returns the root of {@code x}'s tree in the union-find forest stored in
 * {@code parent}, compressing the path by halving: every visited node is
 * re-pointed at its grandparent.
 */
int find(int[] parent, int x) {
    int node = x;
    for (; parent[node] != node; node = parent[node]) {
        parent[node] = parent[parent[node]];
    }
    return node;
}
/// Counts unordered pairs of nodes that lie in different connected components.
///
/// Builds a disjoint-set forest (union by size, path halving), then for each
/// component root in index order adds size * (nodes in components not yet
/// visited).
///
/// @param n      number of nodes, labelled 0..n-1
/// @param edges  undirected edges, each a two-element {a, b} vector; read only
/// @return       number of mutually unreachable pairs; 64-bit because it can
///               reach n(n-1)/2, which overflows int for n near 1e5
long long countPairs(int n, const std::vector<std::vector<int>>& edges) {
    std::vector<int> parent(n);
    std::vector<int> sz(n, 1);
    std::iota(parent.begin(), parent.end(), 0);

    // A plain lambda instead of std::function: find is not recursive, so type
    // erasure buys nothing and would prevent inlining of the hottest call.
    auto find = [&parent](int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];  // path halving
            x = parent[x];
        }
        return x;
    };

    for (const auto& e : edges) {
        int ra = find(e[0]);
        int rb = find(e[1]);
        if (ra == rb) continue;
        if (sz[ra] < sz[rb]) std::swap(ra, rb);
        parent[rb] = ra;  // union by size: smaller tree under the larger root
        sz[ra] += sz[rb];
    }

    long long ans = 0;
    long long remaining = n;
    for (int v = 0; v < n; ++v) {
        if (find(v) != v) continue;
        remaining -= sz[v];
        ans += static_cast<long long>(sz[v]) * remaining;
    }
    return ans;
}
Time: O((n + e) · α(n)), where e is the number of edges and α is the inverse Ackermann function. Space: O(n).