Count Nodes With the Highest Score

medium tree dfs binary tree

Problem

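A binary tree has n nodes labeled 0..n-1, given as an array `parents` where `parents[i]` is the parent of node i and the root's entry is -1. Removing a node (and its incident edges) splits the tree into at most three non-empty components: its left subtree, its right subtree, and the rest of the tree. The score of a node is the product of the sizes of those non-empty components. Return how many nodes share the maximum score.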

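Input: parents = [-1, 2, 0, 2, 0]
Output: 3
Explanation: node 0 is the root with children 2 and 4, and node 2 has children 1 and 3. The scores are 3 (node 0), 4 (node 1), 2 (node 2), 4 (node 3) and 4 (node 4), so three nodes (1, 3 and 4) share the maximum score of 4.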

def count_highest_score_nodes(parents):
    """Count the nodes whose removal produces the highest score.

    Removing node ``u`` leaves up to three non-empty components: the subtree
    of each child and the rest of the tree (``n - size(u)`` nodes). The score
    of ``u`` is the product of the sizes of those non-empty components.

    Subtree sizes are computed iteratively, so very deep (e.g. chain-shaped)
    trees do not hit Python's recursion limit.

    Args:
        parents: ``parents[v]`` is the parent of node ``v``; the root has -1.

    Returns:
        The number of nodes sharing the maximum score (0 for an empty list).

    Raises:
        ValueError: if a non-empty ``parents`` contains no root (-1).
    """
    n = len(parents)
    if n == 0:
        return 0
    root = parents.index(-1)
    children = [[] for _ in range(n)]
    for v, p in enumerate(parents):
        if p != -1:
            children[p].append(v)

    # Breadth-first order from the root: every parent precedes its children.
    order = [root]
    head = 0
    while head < len(order):
        order.extend(children[order[head]])
        head += 1

    # Walking the order backwards finalizes each subtree before its parent.
    size = [1] * n
    for u in reversed(order):
        p = parents[u]
        if p != -1:
            size[p] += size[u]

    best, freq = 0, 0
    for u in order:
        score = 1
        for c in children[u]:
            score *= size[c]
        rest = n - size[u]
        if rest > 0:  # the root has no "rest" component
            score *= rest
        if score > best:
            best, freq = score, 1
        elif score == best:
            freq += 1
    return freq
/**
 * Counts the nodes whose removal produces the highest score.
 *
 * Removing node u leaves up to three non-empty components: the subtree of
 * each child and the rest of the tree (n - size(u) nodes). The score of u is
 * the product of the sizes of those non-empty components. Scores are BigInt
 * so the product is exact for any tree size.
 *
 * Subtree sizes are computed iteratively, so deep (e.g. chain-shaped) trees
 * cannot overflow the call stack.
 *
 * @param {number[]} parents parents[v] is the parent of node v; the root has -1.
 * @returns {number} how many nodes share the maximum score (0 for an empty array).
 */
function countHighestScoreNodes(parents) {
  const n = parents.length;
  if (n === 0) return 0;
  const children = Array.from({ length: n }, () => []);
  let root = 0;
  for (let v = 0; v < n; v++) {
    if (parents[v] === -1) root = v;
    else children[parents[v]].push(v);
  }

  // Breadth-first order from the root: every parent precedes its children.
  const order = [root];
  for (let head = 0; head < order.length; head++) {
    for (const c of children[order[head]]) order.push(c);
  }

  // Walking the order backwards finalizes each subtree before its parent.
  const size = new Array(n).fill(1);
  for (let i = order.length - 1; i >= 0; i--) {
    const u = order[i];
    if (parents[u] !== -1) size[parents[u]] += size[u];
  }

  let best = 0n, freq = 0;
  for (const u of order) {
    let score = 1n;
    for (const c of children[u]) score *= BigInt(size[c]);
    const rest = n - size[u];
    if (rest > 0) score *= BigInt(rest); // the root has no "rest" component
    if (score > best) { best = score; freq = 1; }
    else if (score === best) freq++;
  }
  return freq;
}
class Solution {
    // State of the most recent countHighestScoreNodes call. Kept as fields for
    // compatibility; every call resets them first, so one Solution instance can
    // be reused safely across calls.
    long best = 0;
    int freq = 0;
    int n;
    java.util.List<java.util.List<Integer>> children;

    /**
     * Counts the nodes of a rooted tree whose removal yields the highest score.
     *
     * <p>Removing node {@code u} leaves up to three non-empty components: the subtree of
     * each child and the rest of the tree ({@code n - size(u)} nodes). The score of
     * {@code u} is the product of the sizes of those non-empty components.
     *
     * <p>Subtree sizes are computed iteratively (breadth-first order walked backwards), so
     * deep, chain-shaped trees cannot overflow the call stack.
     *
     * @param parents {@code parents[v]} is the parent of node {@code v}; the root has {@code -1}
     * @return how many nodes share the maximum score, or {@code 0} for an empty array
     */
    public int countHighestScoreNodes(int[] parents) {
        n = parents.length;
        best = 0;
        freq = 0;
        children = new java.util.ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            children.add(new java.util.ArrayList<>());
        }
        if (n == 0) {
            return 0;
        }
        int root = 0;
        for (int v = 0; v < n; v++) {
            if (parents[v] == -1) {
                root = v;
            } else {
                children.get(parents[v]).add(v);
            }
        }

        // Breadth-first order from the root: every parent appears before its children.
        int[] order = new int[n];
        int tail = 0;
        order[tail++] = root;
        for (int head = 0; head < tail; head++) {
            for (int c : children.get(order[head])) {
                order[tail++] = c;
            }
        }

        // Walking that order backwards finalizes each subtree before its parent.
        int[] size = new int[n];
        java.util.Arrays.fill(size, 1);
        for (int i = tail - 1; i >= 0; i--) {
            int u = order[i];
            if (parents[u] != -1) {
                size[parents[u]] += size[u];
            }
        }

        for (int i = 0; i < tail; i++) {
            int u = order[i];
            // long: a product of up to three component sizes can exceed int range for large n.
            long score = 1L;
            for (int c : children.get(u)) {
                score *= size[c];
            }
            int rest = n - size[u];
            if (rest > 0) { // the root has no "rest" component
                score *= rest;
            }
            if (score > best) {
                best = score;
                freq = 1;
            } else if (score == best) {
                freq++;
            }
        }
        return freq;
    }
}
// Counts the nodes whose removal produces the highest score.
//
// Removing node u leaves up to three non-empty components: the subtree of each
// child and the rest of the tree (n - size(u) nodes). The score of u is the
// product of the sizes of those non-empty components.
//
// parents[v] is the parent of node v; the root has -1. Returns how many nodes
// share the maximum score (0 for an empty vector). Subtree sizes are computed
// iteratively, so deep, chain-shaped trees cannot exhaust the call stack.
int countHighestScoreNodes(vector<int>& parents) {
    const int n = static_cast<int>(parents.size());
    if (n == 0) return 0;
    vector<vector<int>> children(n);
    int root = 0;
    for (int v = 0; v < n; v++) {
        if (parents[v] == -1) root = v;
        else children[parents[v]].push_back(v);
    }

    // Breadth-first order from the root: every parent precedes its children.
    vector<int> order;
    order.reserve(n);
    order.push_back(root);
    for (size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        for (int c : children[u]) order.push_back(c);
    }

    // Walking the order backwards finalizes each subtree before its parent.
    vector<int> subtree(n, 1);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int p = parents[*it];
        if (p != -1) subtree[p] += subtree[*it];
    }

    long long best = 0;
    int freq = 0;
    for (int u : order) {
        // long long: a product of up to three component sizes can exceed int range.
        long long score = 1;
        for (int c : children[u]) score *= subtree[c];
        const int rest = n - subtree[u];
        if (rest > 0) score *= rest;  // the root has no "rest" component
        if (score > best) { best = score; freq = 1; }
        else if (score == best) ++freq;
    }
    return freq;
}
Time: O(n) Space: O(n)