Number of Good Paths

Difficulty: hard · Tags: tree, union-find, sorting

Problem

You are given a tree with n nodes (numbered 0 to n − 1), where node i carries the value vals[i]. A "good path" starts and ends at nodes with the same value, and no node on the path has a value greater than that value. Count all distinct good paths, including single-node (length-zero) paths; a path and its reverse count as one path.

Input: vals = [1,3,2,1,3], edges = [[0,1],[0,2],[2,3],[2,4]]
Output: 6
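Explanation: there are five single-node paths, plus one longer path, 1 → 0 → 2 → 4. It joins the two value-3 nodes (1 and 4), and its intermediate nodes 0 (value 1) and 2 (value 2) do not exceed 3. The two value-1 nodes (0 and 3) do not form a good path, because node 2 (value 2) lies between them. Total: 5 + 1 = 6.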

def number_of_good_paths(vals, edges):
    """Count the good paths in a tree.

    A good path starts and ends at nodes holding the same value, and no node
    on it has a larger value. Single-node paths count, and a path and its
    reverse are counted once.

    Nodes are activated in ascending value order. Each one is merged with its
    already-active neighbours via union-find. Every node activated so far has
    a value <= the current node's, and the current node becomes the root of
    each merge. So a root always carries its component's maximum value, and
    ``cnt[root]`` is the number of nodes in that component holding it.

    Args:
        vals: ``vals[i]`` is the value of node ``i``.
        edges: undirected tree edges ``[u, v]``.

    Returns:
        The number of good paths.
    """
    n = len(vals)
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    parent = list(range(n))

    def find(x):
        # Iterative find with path halving.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    cnt = [1] * n
    active = [False] * n
    ans = n  # every single-node path is good
    for u in sorted(range(n), key=lambda i: vals[i]):
        active[u] = True
        for v in adj[u]:
            # Only neighbours already activated (value <= vals[u]) may join u.
            if not active[v] or vals[v] > vals[u]:
                continue
            ru, rv = find(u), find(v)
            if ru == rv:
                continue
            if vals[ru] == vals[rv] == vals[u]:
                # Each max-valued node on one side pairs with each on the other.
                ans += cnt[ru] * cnt[rv]
                cnt[ru] += cnt[rv]
            parent[rv] = ru
    return ans
/**
 * Count good paths in a tree: paths whose two endpoints share a value that no
 * node on the path exceeds. Single-node paths are included.
 *
 * @param {number[]} vals   vals[i] is the value of node i.
 * @param {number[][]} edges undirected tree edges [a, b].
 * @returns {number} the number of good paths.
 */
function numberOfGoodPaths(vals, edges) {
  const size = vals.length;

  const neighbours = [];
  for (let i = 0; i < size; i++) neighbours.push([]);
  edges.forEach(([a, b]) => {
    neighbours[a].push(b);
    neighbours[b].push(a);
  });

  // Union-find. link[i] is i's parent. For a root r, maxCount[r] counts the
  // nodes in r's component whose value equals the component maximum.
  const link = [];
  const maxCount = [];
  for (let i = 0; i < size; i++) {
    link.push(i);
    maxCount.push(1);
  }
  const findRoot = (node) => {
    let cur = node;
    while (link[cur] !== cur) {
      link[cur] = link[link[cur]]; // path halving
      cur = link[cur];
    }
    return cur;
  };

  const byValue = Array.from({ length: size }, (_, i) => i);
  byValue.sort((a, b) => vals[a] - vals[b]);

  const seen = new Array(size).fill(false);
  let total = size; // every single node is a good path

  for (const node of byValue) {
    seen[node] = true;
    for (const other of neighbours[node]) {
      // Skip neighbours not yet activated or with a larger value.
      if (!seen[other] || vals[other] > vals[node]) continue;
      const here = findRoot(node);
      const there = findRoot(other);
      if (here === there) continue;
      const sharesMax = vals[here] === vals[node] && vals[there] === vals[node];
      if (sharesMax) {
        // Each max-valued node on one side pairs with each on the other.
        total += maxCount[here] * maxCount[there];
        maxCount[here] += maxCount[there];
      }
      link[there] = here;
    }
  }
  return total;
}
/**
 * Counts good paths in a tree: paths whose two endpoints share a value that no
 * node on the path exceeds (single-node paths included).
 */
class Solution {
    // Union-find parent links. For a root r, cnt[r] is the number of nodes in
    // r's component whose value equals the component maximum (held at r).
    int[] parent, cnt;

    /**
     * Activates nodes in ascending value order and unions each with its
     * already-active neighbours. The node being processed becomes the root of
     * every merge, so a root always holds the maximum value of its component.
     *
     * @param vals  vals[i] is the value of node i
     * @param edges undirected tree edges {a, b}
     * @return the number of good paths
     */
    public int numberOfGoodPaths(int[] vals, int[][] edges) {
        int n = vals.length;
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) { adj.get(e[0]).add(e[1]); adj.get(e[1]).add(e[0]); }
        parent = new int[n]; cnt = new int[n];
        for (int i = 0; i < n; i++) { parent[i] = i; cnt[i] = 1; }
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        // Integer.compare rather than vals[a] - vals[b]: the subtraction can
        // overflow for large values of opposite sign and mis-order nodes.
        Arrays.sort(order, (a, b) -> Integer.compare(vals[a], vals[b]));
        boolean[] active = new boolean[n];
        int ans = n; // every single node is a good path
        for (int u : order) {
            active[u] = true;
            for (int v : adj.get(u)) {
                // Only already-active neighbours (value <= vals[u]) may join u.
                if (!active[v] || vals[v] > vals[u]) continue;
                int ru = find(u), rv = find(v);
                if (ru == rv) continue;
                if (vals[ru] == vals[u] && vals[rv] == vals[u]) {
                    // Each max-valued node on one side pairs with each on the other.
                    ans += cnt[ru] * cnt[rv];
                    cnt[ru] += cnt[rv];
                }
                parent[rv] = ru;
            }
        }
        return ans;
    }

    /** Returns the root of x's set, halving the path as it walks up. */
    int find(int x) {
        while (parent[x] != x) { parent[x] = parent[parent[x]]; x = parent[x]; }
        return x;
    }
}
// Union-find state used by find_ and numberOfGoodPaths. parent_ holds parent
// links. For a root r, cnt_[r] is the number of nodes in r's component whose
// value equals the component maximum. Both are resized to n on every call to
// numberOfGoodPaths, so inputs of any size fit (no fixed capacity).
vector<int> parent_, cnt_;

// Returns the root of x's set, applying path halving along the way.
int find_(int x) {
    while (parent_[x] != x) { parent_[x] = parent_[parent_[x]]; x = parent_[x]; }
    return x;
}

// Counts good paths in a tree: paths whose endpoints share a value that no
// node on the path exceeds (single-node paths included). Nodes are activated
// in ascending value order and unioned with their already-active neighbours.
// The current node becomes the root of each merge, so a root always holds its
// component's maximum value.
int numberOfGoodPaths(vector<int>& vals, vector<vector<int>>& edges) {
    int n = vals.size();
    vector<vector<int>> adj(n);
    for (auto& e : edges) { adj[e[0]].push_back(e[1]); adj[e[1]].push_back(e[0]); }
    parent_.resize(n);
    iota(parent_.begin(), parent_.end(), 0);
    cnt_.assign(n, 1);
    vector<int> order(n);
    iota(order.begin(), order.end(), 0);
    sort(order.begin(), order.end(), [&](int a, int b) { return vals[a] < vals[b]; });
    vector<bool> active(n, false);
    int ans = n;  // every single node is a good path
    for (int u : order) {
        active[u] = true;
        for (int v : adj[u]) {
            // Only already-active neighbours (value <= vals[u]) may join u.
            if (!active[v] || vals[v] > vals[u]) continue;
            int ru = find_(u), rv = find_(v);
            if (ru == rv) continue;
            if (vals[ru] == vals[u] && vals[rv] == vals[u]) {
                // Each max-valued node on one side pairs with each on the other.
                ans += cnt_[ru] * cnt_[rv];
                cnt_[ru] += cnt_[rv];
            }
            parent_[rv] = ru;
        }
    }
    return ans;
}
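Time: O(n log n), dominated by sorting the nodes by value. The union-find uses path halving without union by rank, which also stays within O(n log n) overall. Space: O(n).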