Minimize Malware Spread

hard graph union-find components

Problem

You are given an n × n adjacency matrix graph where graph[i][j] == 1 means nodes i and j are directly connected. Malware spreads through every connection, so every node in a connected component that contains an infected node becomes infected. Given a list initial of initially infected nodes, remove exactly one node from initial so that the final number of infected nodes, M(initial), is as small as possible. Return the node to remove; if several nodes give the same result, return the one with the smallest index. Note that the removed node stays in the graph, so it can still be infected later if another infected node shares its component — removing a node only helps when it is the sole infected node of its component.

Input: graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0, 1]
Output: 0
Nodes 0 and 1 form one component that holds both infected nodes, so removing either one still leaves that whole component infected; node 2 stays clean in every case. Both choices give M = 2, so we break the tie by returning the smaller index, 0.

def min_malware_spread(graph, initial):
    """Pick the initially infected node whose removal minimizes malware spread.

    graph is an n x n adjacency matrix (a truthy graph[i][j] links nodes i
    and j); initial lists the initially infected nodes and is not modified.

    Removing a node only helps when it is the sole infected node of its
    connected component; the saving is then the size of that component.
    Returns the node with the largest saving, preferring the smallest index
    on ties. If no removal saves anything, returns min(initial).

    Raises ValueError (from min) when initial is empty.
    """
    n = len(graph)
    leader = list(range(n))

    def root(v):
        # Path halving: re-point each visited node at its grandparent.
        while leader[v] != v:
            leader[v] = leader[leader[v]]
            v = leader[v]
        return v

    for a in range(n):
        row = graph[a]
        for b in range(a + 1, n):
            if row[b]:
                leader[root(a)] = root(b)

    component_size = {}
    for v in range(n):
        key = root(v)
        component_size[key] = component_size.get(key, 0) + 1

    infected_in = {}
    for v in initial:
        key = root(v)
        infected_in[key] = infected_in.get(key, 0) + 1

    # Encode each candidate as (-saved, node) so min() prefers the largest
    # saving first and the smallest index second.
    options = [
        (-component_size[root(v)], v)
        for v in initial
        if infected_in[root(v)] == 1
    ]
    if not options:
        return min(initial)
    return min(options)[1]
/**
 * Returns the node of `initial` whose removal minimizes the final number of
 * infected nodes.
 *
 * Removing a node only helps when it is the sole infected node of its
 * connected component; the saving is then that component's size.
 *
 * @param {number[][]} graph  n x n adjacency matrix; a truthy graph[i][j] links i and j.
 * @param {number[]} initial  Initially infected node indices (not modified).
 * @returns {number} The node with the largest saving, smallest index on ties;
 *   Math.min(...initial) when no single removal saves anything.
 */
function minMalwareSpread(graph, initial) {
  const n = graph.length;
  const parent = [];
  for (let v = 0; v < n; v++) parent.push(v);

  // Root lookup with path halving (each visited node jumps to its grandparent).
  const root = (v) => {
    while (parent[v] !== v) {
      parent[v] = parent[parent[v]];
      v = parent[v];
    }
    return v;
  };

  for (let a = 0; a < n; a++) {
    const row = graph[a];
    for (let b = a + 1; b < n; b++) {
      if (row[b]) parent[root(a)] = root(b);
    }
  }

  const compSize = new Map();
  for (let v = 0; v < n; v++) {
    const r = root(v);
    compSize.set(r, (compSize.get(r) || 0) + 1);
  }

  const infected = new Map();
  for (const node of initial) {
    const r = root(node);
    infected.set(r, (infected.get(r) || 0) + 1);
  }

  let best = Math.min(...initial);
  let bestSave = -1;
  // Ascending order plus a strict ">" makes the smallest index win ties.
  const ascending = initial.slice().sort((x, y) => x - y);
  for (const node of ascending) {
    const r = root(node);
    if (infected.get(r) !== 1) continue;
    const saved = compSize.get(r);
    if (saved > bestSave) {
      bestSave = saved;
      best = node;
    }
  }
  return best;
}
class Solution {
    /** Union-find parent links; parent[x] == x marks a component root. */
    int[] parent;

    /** Returns the root of x's component, re-pointing visited nodes at their grandparents (path halving). */
    int find(int x) {
        while (parent[x] != x) { parent[x] = parent[parent[x]]; x = parent[x]; }
        return x;
    }

    /**
     * Returns the node of {@code initial} whose removal minimizes the final number of infected nodes.
     *
     * <p>Removing a node only helps when it is the sole infected node of its connected component;
     * the saving is then that component's size. The largest saving wins and ties go to the smallest
     * index; if no removal saves anything, the smallest node in {@code initial} is returned.
     *
     * @param graph   n x n adjacency matrix; {@code graph[i][j] == 1} links nodes i and j
     * @param initial initially infected nodes; this array is not modified
     * @return the node to remove
     * @throws ArrayIndexOutOfBoundsException if {@code initial} is empty
     */
    public int minMalwareSpread(int[][] graph, int[] initial) {
        int n = graph.length;
        parent = new int[n];
        for (int i = 0; i < n; i++) parent[i] = i;
        for (int i = 0; i < n; i++)
            for (int j = i + 1; j < n; j++)
                if (graph[i][j] == 1) parent[find(i)] = find(j);

        int[] size = new int[n], count = new int[n];
        for (int i = 0; i < n; i++) size[find(i)]++;
        for (int node : initial) count[find(node)]++;

        // Fallback when no removal saves anything: the smallest infected index.
        // Reading initial[0] keeps the original failure mode for an empty array.
        int best = initial[0];
        for (int node : initial) if (node < best) best = node;

        int bestSave = -1;
        for (int node : initial) {
            int r = find(node);
            // Another infected node would still reach this component.
            if (count[r] != 1) continue;
            // Explicit tie-break replaces sorting the caller's array in place.
            if (size[r] > bestSave || (size[r] == bestSave && node < best)) {
                bestSave = size[r];
                best = node;
            }
        }
        return best;
    }
}
class Solution {
    // Union-find parent links; parent[x] == x marks a component root.
    std::vector<int> parent;

    // Returns the root of x's component, re-pointing each visited node at
    // its grandparent on the way up (path halving).
    int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

public:
    // Returns the node in `initial` whose removal minimizes the final number
    // of infected nodes.
    //
    // graph is an n x n adjacency matrix (non-zero graph[i][j] links i and j).
    // Removing a node only helps when it is the sole infected node of its
    // connected component; the saving is then that component's size. The
    // largest saving wins, ties go to the smallest index, and if no removal
    // saves anything the smallest node in `initial` is returned.
    // `initial` is NOT modified. Returns -1 if `initial` is empty.
    int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial) {
        // Previously initial[0] was read unconditionally (UB on empty input).
        if (initial.empty()) return -1;

        const int n = static_cast<int>(graph.size());
        parent.assign(n, 0);
        for (int i = 0; i < n; i++) parent[i] = i;
        for (int i = 0; i < n; i++)
            for (int j = i + 1; j < n; j++)
                if (graph[i][j]) parent[find(i)] = find(j);

        std::vector<int> compSize(n, 0), infectedIn(n, 0);
        for (int i = 0; i < n; i++) compSize[find(i)]++;
        for (int node : initial) infectedIn[find(node)]++;

        // Fallback when no removal saves anything: the smallest infected index.
        int best = initial[0];
        for (int node : initial)
            if (node < best) best = node;

        int bestSave = -1;
        for (int node : initial) {
            const int r = find(node);
            // Another infected node would still reach this component.
            if (infectedIn[r] != 1) continue;
            // Explicit tie-break replaces sorting the caller's vector in place.
            if (compSize[r] > bestSave || (compSize[r] == bestSave && node < best)) {
                bestSave = compSize[r];
                best = node;
            }
        }
        return best;
    }
};
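Time: O(n² · α(n)) — every pair of the n × n adjacency matrix is scanned once, with a near-constant union-find step per edge. Space: O(n) beyond the input matrix (parent, component sizes, infected counts).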