Minimize Malware Spread II

hard graph union find dfs

Problem

You are given an undirected network as an adjacency matrix graph where graph[i][j] = 1 means nodes i and j are connected. Some nodes are initially infected, listed in initial. Malware spreads between any two directly connected nodes. After the spread finishes, M(initial) is the number of infected nodes. You must completely remove one node from initial (deleting it and all its edges from the graph) to minimize the final number of infected nodes. Return the node with the smallest index that achieves this minimum.

Input: graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1]
Output: 1
Node 2 (clean) is adjacent only to node 1. If node 0 is removed, node 1 stays infected and spreads to node 2, so two nodes end up infected. If node 1 is removed (together with its edges), node 2 is cut off from the remaining source 0, so only node 0 ends up infected. Removing node 1 therefore minimizes the spread, and the answer is 1.

def min_malware_spread(graph, initial):
    """Pick the infected node whose complete removal minimizes the final spread.

    Args:
        graph: n x n adjacency matrix; a truthy ``graph[i][j]`` means nodes
            ``i`` and ``j`` are directly connected.
        initial: non-empty list of initially infected node indices.

    Returns:
        The node from ``initial`` that rescues the most clean nodes when it
        is removed together with its edges; ties go to the smallest index.

    Raises:
        ValueError: if ``initial`` is empty.
    """
    n = len(graph)
    infected = set(initial)

    # reach[c] lists every source that can get to clean node c by walking
    # through clean nodes only (other sources block the walk).
    reach = {}
    for v in initial:
        seen = set()
        stack = [u for u in range(n)
                 if graph[v][u] and u not in infected]
        while stack:                 # DFS over clean nodes only
            x = stack.pop()
            if x in seen:
                continue
            seen.add(x)
            for y in range(n):
                if graph[x][y] and y not in infected and y not in seen:
                    stack.append(y)
        for c in seen:
            reach.setdefault(c, []).append(v)

    # A clean node is rescued only if exactly one source reaches it:
    # removing that source then leaves the node with no path to any malware.
    rescued = {}
    for srcs in reach.values():
        if len(srcs) == 1:
            rescued[srcs[0]] = rescued.get(srcs[0], 0) + 1

    # Most rescued nodes first; smallest index breaks ties.
    return min(initial, key=lambda v: (-rescued.get(v, 0), v))
/**
 * Chooses the initially infected node whose complete removal (node plus
 * edges) leaves the fewest infected nodes.
 *
 * @param {number[][]} graph   n x n adjacency matrix; truthy entry = edge.
 * @param {number[]}   initial indices of the initially infected nodes.
 * @returns {number} the best node to remove; ties go to the smallest index.
 */
function minMalwareSpread(graph, initial) {
  const size = graph.length;
  const sources = new Set(initial);

  // For every clean node: how many sources reach it through clean nodes
  // only, and which source reached it first.
  const hits = new Map();
  initial.forEach((src) => {
    const visited = new Set();
    const pending = [];
    for (let nb = 0; nb < size; nb++) {
      if (graph[src][nb] && !sources.has(nb)) {
        pending.push(nb);
      }
    }
    while (pending.length > 0) {
      const cur = pending.pop();
      if (visited.has(cur)) {
        continue;
      }
      visited.add(cur);
      for (let nxt = 0; nxt < size; nxt++) {
        const blocked = sources.has(nxt) || visited.has(nxt);
        if (graph[cur][nxt] && !blocked) {
          pending.push(nxt);
        }
      }
    }
    visited.forEach((node) => {
      const entry = hits.get(node);
      if (entry === undefined) {
        hits.set(node, { count: 1, first: src });
      } else {
        entry.count += 1;
      }
    });
  });

  // A clean node is rescued by a source only when that source is the sole
  // one able to reach it.
  const rescuedBy = new Map();
  hits.forEach(({ count, first }) => {
    if (count === 1) {
      rescuedBy.set(first, (rescuedBy.get(first) || 0) + 1);
    }
  });

  let answer = Math.min(...initial);
  let mostRescued = -1;
  for (const src of initial) {
    const rescued = rescuedBy.get(src) || 0;
    const better = rescued > mostRescued;
    const tieWithSmallerIndex = rescued === mostRescued && src < answer;
    if (better || tieWithSmallerIndex) {
      answer = src;
      mostRescued = rescued;
    }
  }
  return answer;
}
class Solution {
    /**
     * Chooses the initially infected node whose complete removal (the node
     * and all of its edges) leaves the fewest infected nodes.
     *
     * @param graph   n x n adjacency matrix; {@code graph[i][j] == 1} means an edge
     * @param initial indices of the initially infected nodes
     * @return the node to remove; ties are broken by the smallest index
     */
    public int minMalwareSpread(int[][] graph, int[] initial) {
        final int size = graph.length;
        final boolean[] isSource = new boolean[size];
        for (int src : initial) {
            isSource[src] = true;
        }

        // For each clean node: number of sources that reach it through clean
        // nodes only, and the first such source.
        final int[] sourceCount = new int[size];
        final int[] firstSource = new int[size];
        for (int src : initial) {
            final boolean[] visited = new boolean[size];
            final Deque<Integer> pending = new ArrayDeque<>();
            for (int nb = 0; nb < size; nb++) {
                if (graph[src][nb] == 1 && !isSource[nb]) {
                    pending.push(nb);
                }
            }
            while (!pending.isEmpty()) {
                final int cur = pending.pop();
                if (visited[cur]) {
                    continue;
                }
                visited[cur] = true;
                if (sourceCount[cur] == 0) {
                    firstSource[cur] = src;
                }
                sourceCount[cur]++;
                for (int nxt = 0; nxt < size; nxt++) {
                    if (graph[cur][nxt] == 1 && !isSource[nxt] && !visited[nxt]) {
                        pending.push(nxt);
                    }
                }
            }
        }

        // A clean node is rescued by a source only when that source is the
        // sole one able to reach it.
        final int[] rescuedBy = new int[size];
        for (int node = 0; node < size; node++) {
            if (sourceCount[node] == 1) {
                rescuedBy[firstSource[node]]++;
            }
        }

        int answer = Integer.MAX_VALUE;
        for (int src : initial) {
            answer = Math.min(answer, src);
        }
        int mostRescued = -1;
        for (int src : initial) {
            final int rescued = rescuedBy[src];
            final boolean better = rescued > mostRescued;
            final boolean tieWithSmallerIndex = rescued == mostRescued && src < answer;
            if (better || tieWithSmallerIndex) {
                answer = src;
                mostRescued = rescued;
            }
        }
        return answer;
    }
}
// Chooses the initially infected node whose complete removal (the node and
// all of its edges) leaves the fewest infected nodes. `graph` is an n x n
// adjacency matrix (non-zero entry = edge); `initial` lists the infected
// nodes. Ties are broken by the smallest index.
int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
    const int size = graph.size();
    vector<bool> isSource(size, false);
    for (int src : initial) {
        isSource[src] = true;
    }

    // For each clean node: number of sources that reach it through clean
    // nodes only, and the first such source.
    vector<int> sourceCount(size, 0);
    vector<int> firstSource(size, 0);
    for (int src : initial) {
        vector<bool> visited(size, false);
        vector<int> pending;
        for (int nb = 0; nb < size; ++nb) {
            if (graph[src][nb] && !isSource[nb]) {
                pending.push_back(nb);
            }
        }
        while (!pending.empty()) {
            const int cur = pending.back();
            pending.pop_back();
            if (visited[cur]) {
                continue;
            }
            visited[cur] = true;
            if (sourceCount[cur] == 0) {
                firstSource[cur] = src;
            }
            ++sourceCount[cur];
            for (int nxt = 0; nxt < size; ++nxt) {
                if (graph[cur][nxt] && !isSource[nxt] && !visited[nxt]) {
                    pending.push_back(nxt);
                }
            }
        }
    }

    // A clean node is rescued by a source only when that source is the sole
    // one able to reach it.
    vector<int> rescuedBy(size, 0);
    for (int node = 0; node < size; ++node) {
        if (sourceCount[node] == 1) {
            ++rescuedBy[firstSource[node]];
        }
    }

    int answer = *min_element(initial.begin(), initial.end());
    int mostRescued = -1;
    for (int src : initial) {
        const int rescued = rescuedBy[src];
        const bool better = rescued > mostRescued;
        const bool tieWithSmallerIndex = rescued == mostRescued && src < answer;
        if (better || tieWithSmallerIndex) {
            answer = src;
            mostRescued = rescued;
        }
    }
    return answer;
}
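Time: O(|initial| · n²), since each source runs one DFS over the adjacency matrix. Space: O(n²) in the worst case, because the DFS stack can hold duplicate entries and up to |initial| sources are recorded per clean node.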