Remove Max Number of Edges to Keep Graph Fully Traversable

Difficulty: Hard · Topics: union-find, greedy

Problem

An undirected graph has n nodes (labeled 1..n) and edges of three types: type 1 edges only Alice may walk, type 2 edges only Bob may walk, and type 3 edges both may walk. The graph is fully traversable when Alice can reach every node using types 1 and 3, and Bob can reach every node using types 2 and 3.

Delete as many edges as possible while keeping the graph fully traversable for both, and return that maximum count — or −1 if the graph is not fully traversable even with every edge kept.

Input: n = 4, edges = [[3,1,2],[3,2,3],[1,1,3],[1,2,4],[1,1,2],[2,3,4]]
Output: 2
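Deleting [1,1,3] and [1,1,2] still lets both Alice and Bob reach all 4 nodes, so 2 edges can go. No third edge can: Alice and Bob each need 3 edges to connect 4 nodes, and only the two type-3 edges can be shared, so at least 3 + 3 − 2 = 4 of the 6 edges must stay.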

def max_num_edges_to_remove(n, edges):
    """Return how many edges can be deleted while the graph stays fully traversable.

    Alice may walk type-1 and type-3 edges, Bob type-2 and type-3 edges; both
    must be able to reach every node 1..n. Greedy strategy: consider the shared
    type-3 edges first (one shared edge can stand in for two single-user
    edges), then let type-1 edges finish Alice's spanning forest and type-2
    edges finish Bob's. An edge is deletable exactly when its endpoints are
    already connected in the forest(s) it would join, so the count does not
    depend on the order of edges within a type.

    Args:
        n: Number of nodes, labeled 1..n.
        edges: Iterable of ``[type, u, v]`` triples with ``type`` in {1, 2, 3}.
            It is not modified.

    Returns:
        The maximum number of deletable edges, or -1 if Alice or Bob cannot
        reach every node even with all edges kept.

    Complexity: O(E * alpha(n)) time via two linear passes (no sort) over a
    union-by-size, path-halving disjoint-set forest; O(n) extra space when
    ``edges`` is a re-iterable collection.
    """

    def find(parent, x):
        # Path halving: re-point nodes on the way up at their grandparent.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(parent, size, a, b):
        """Merge the sets containing a and b; return False if already merged."""
        a, b = find(parent, a), find(parent, b)
        if a == b:
            return False
        if size[a] < size[b]:
            a, b = b, a
        parent[b] = a  # hang the smaller tree under the larger root
        size[a] += size[b]
        return True

    alice = (list(range(n + 1)), [1] * (n + 1))
    bob = (list(range(n + 1)), [1] * (n + 1))
    if iter(edges) is edges:
        # A one-shot iterator cannot be walked twice; materialize it.
        edges = list(edges)
    merges = removed = 0

    # Pass 1: shared edges. Both forests receive identical unions here, so a
    # merge in Alice's forest is always a merge in Bob's too.
    for t, a, b in edges:
        if t != 3:
            continue
        if union(*alice, a, b):
            union(*bob, a, b)
            merges += 2
        else:
            removed += 1

    # Pass 2: single-user edges complete each person's own forest.
    for t, a, b in edges:
        if t == 3:
            continue
        if union(*(alice if t == 1 else bob), a, b):
            merges += 1
        else:
            removed += 1

    # Each forest spans all n nodes exactly when it has absorbed n - 1 merges.
    return removed if merges == 2 * (n - 1) else -1
/**
 * Maximum number of edges removable while both Alice (types 1 and 3) and
 * Bob (types 2 and 3) can still reach every node 1..n.
 *
 * Shared type-3 edges are considered first because one of them can replace
 * two single-user edges; type-1 and type-2 edges then complete Alice's and
 * Bob's spanning forests. An edge is removable exactly when its endpoints are
 * already connected in the forest(s) it would join, so the count does not
 * depend on the order of edges within a type.
 *
 * Runs in O(E · α(n)) time with two linear passes (no sort) over a
 * union-by-size, path-halving disjoint-set forest. `edges` is not modified.
 *
 * @param {number} n Number of nodes, labeled 1..n.
 * @param {Iterable<number[]>} edges `[type, u, v]` triples, type in {1, 2, 3}.
 * @returns {number} Removable edge count, or -1 if the graph is not fully
 *   traversable even with every edge kept.
 */
function maxNumEdgesToRemove(n, edges) {
  // Arrays can be walked twice as-is; other iterables (e.g. generators) cannot.
  const list = Array.isArray(edges) ? edges : [...edges];

  const makeForest = () => ({
    parent: Array.from({ length: n + 1 }, (_, i) => i),
    size: Array.from({ length: n + 1 }, () => 1),
  });

  const find = ({ parent }, x) => {
    // Path halving: re-point nodes on the way up at their grandparent.
    while (parent[x] !== x) {
      parent[x] = parent[parent[x]];
      x = parent[x];
    }
    return x;
  };

  // Merges the sets holding a and b; returns false if they were already one.
  const union = (forest, a, b) => {
    let ra = find(forest, a);
    let rb = find(forest, b);
    if (ra === rb) return false;
    if (forest.size[ra] < forest.size[rb]) [ra, rb] = [rb, ra];
    forest.parent[rb] = ra; // hang the smaller tree under the larger root
    forest.size[ra] += forest.size[rb];
    return true;
  };

  const alice = makeForest();
  const bob = makeForest();
  let merges = 0;
  let removed = 0;

  // Pass 1: shared edges. Both forests receive identical unions here, so a
  // merge in Alice's forest is always a merge in Bob's too.
  for (const [t, a, b] of list) {
    if (t !== 3) continue;
    if (union(alice, a, b)) {
      union(bob, a, b);
      merges += 2;
    } else {
      removed++;
    }
  }

  // Pass 2: single-user edges complete each person's own forest.
  for (const [t, a, b] of list) {
    if (t === 3) continue;
    if (union(t === 1 ? alice : bob, a, b)) merges++;
    else removed++;
  }

  // Each forest spans all n nodes exactly when it has absorbed n - 1 merges.
  return merges === 2 * (n - 1) ? removed : -1;
}
class Solution {
    /**
     * Returns the maximum number of edges that can be removed while Alice
     * (edge types 1 and 3) and Bob (types 2 and 3) can each still reach all
     * nodes 1..n, or -1 if that is impossible even with every edge kept.
     *
     * <p>Shared type-3 edges are considered first, since one of them can
     * replace two single-user edges; type-1 and type-2 edges then complete
     * Alice's and Bob's spanning forests. An edge is removable exactly when
     * its endpoints are already connected in the forest(s) it would join, so
     * the count does not depend on the order of edges within a type.
     *
     * <p>Runs in O(E · α(n)) time with two linear passes (no sort) over
     * union-by-size, path-halving disjoint sets. {@code edges} is not modified.
     *
     * @param n     number of nodes, labeled 1..n
     * @param edges {@code [type, u, v]} triples with type in {1, 2, 3}
     * @return the removable edge count, or -1 if not fully traversable
     */
    public int maxNumEdgesToRemove(int n, int[][] edges) {
        DisjointSet alice = new DisjointSet(n + 1);
        DisjointSet bob = new DisjointSet(n + 1);
        int merges = 0, removed = 0;

        // Pass 1: shared edges. Both forests receive identical unions here,
        // so a merge in Alice's forest is always a merge in Bob's too.
        for (int[] e : edges) {
            if (e[0] != 3) continue;
            if (alice.union(e[1], e[2])) {
                bob.union(e[1], e[2]);
                merges += 2;
            } else {
                removed++;
            }
        }

        // Pass 2: single-user edges complete each person's own forest.
        for (int[] e : edges) {
            if (e[0] == 3) continue;
            DisjointSet forest = e[0] == 1 ? alice : bob;
            if (forest.union(e[1], e[2])) merges++;
            else removed++;
        }

        // Each forest spans all n nodes exactly when it absorbed n - 1 merges.
        return merges == 2 * (n - 1) ? removed : -1;
    }

    /** Union-by-size disjoint-set forest over the indices 0..count-1. */
    private static final class DisjointSet {
        private final int[] parent;
        private final int[] size;

        DisjointSet(int count) {
            parent = new int[count];
            size = new int[count];
            for (int i = 0; i < count; i++) {
                parent[i] = i;
                size[i] = 1;
            }
        }

        /** Returns the root of x, halving the path on the way up. */
        int find(int x) {
            while (parent[x] != x) {
                parent[x] = parent[parent[x]];
                x = parent[x];
            }
            return x;
        }

        /** Merges the sets of a and b; returns false if they were already one. */
        boolean union(int a, int b) {
            int ra = find(a), rb = find(b);
            if (ra == rb) return false;
            if (size[ra] < size[rb]) {
                int tmp = ra;
                ra = rb;
                rb = tmp;
            }
            parent[rb] = ra; // hang the smaller tree under the larger root
            size[ra] += size[rb];
            return true;
        }
    }
}
class Solution {
    // Disjoint-set forest: parent links plus subtree sizes (valid at roots).
    struct Forest {
        vector<int> parent;
        vector<int> size;
        explicit Forest(int count) : parent(count), size(count, 1) {
            iota(parent.begin(), parent.end(), 0);
        }
    };

    // Union by size: merges the sets of a and b, hanging the smaller tree
    // under the larger root. Returns false if a and b were already joined.
    bool unite(Forest& f, int a, int b) {
        a = find(f.parent, a);
        b = find(f.parent, b);
        if (a == b) return false;
        if (f.size[a] < f.size[b]) swap(a, b);
        f.parent[b] = a;
        f.size[a] += f.size[b];
        return true;
    }

public:
    /// @brief Maximum number of edges removable while Alice (types 1 and 3)
    ///        and Bob (types 2 and 3) can each still reach every node 1..n.
    ///
    /// Shared type-3 edges are considered first, since one of them can replace
    /// two single-user edges; type-1 and type-2 edges then complete Alice's and
    /// Bob's spanning forests. An edge is removable exactly when its endpoints
    /// are already connected in the forest(s) it would join, so the count does
    /// not depend on the order of edges within a type.
    ///
    /// Runs in O(E · α(n)) time with two linear passes (no sort) over
    /// union-by-size, path-halving disjoint sets.
    /// @param n     Number of nodes, labeled 1..n.
    /// @param edges `{type, u, v}` triples with type in {1, 2, 3}; not modified.
    /// @return The removable edge count, or -1 if the graph is not fully
    ///         traversable even with every edge kept.
    int maxNumEdgesToRemove(int n, vector<vector<int>>& edges) {
        Forest alice(n + 1), bob(n + 1);
        int merges = 0, removed = 0;

        // Pass 1: shared edges. Both forests receive identical unions here,
        // so a merge in Alice's forest is always a merge in Bob's too.
        for (const auto& e : edges) {
            if (e[0] != 3) continue;
            if (unite(alice, e[1], e[2])) {
                unite(bob, e[1], e[2]);
                merges += 2;
            } else {
                ++removed;
            }
        }

        // Pass 2: single-user edges complete each person's own forest.
        for (const auto& e : edges) {
            if (e[0] == 3) continue;
            if (unite(e[0] == 1 ? alice : bob, e[1], e[2])) ++merges;
            else ++removed;
        }

        // Each forest spans all n nodes exactly when it absorbed n - 1 merges.
        return merges == 2 * (n - 1) ? removed : -1;
    }

    // Returns the root of x in parent array p, halving the path on the way up.
    int find(vector<int>& p, int x) {
        while (p[x] != x) x = p[x] = p[p[x]];
        return x;
    }

    // Naive-linking union on a bare parent array; kept for existing callers.
    // Returns false if a and b were already in the same set.
    bool unite(vector<int>& p, int a, int b) {
        a = find(p, a); b = find(p, b);
        if (a == b) return false;
        p[a] = b;
        return true;
    }
};
Time: O(E · α(n)) Space: O(n)