Count Nodes With the Highest Score
Problem
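A binary tree has n nodes labeled 0..n-1, given as an array parents where parents[i] is the parent of node i and the root has parents[root] = -1. Removing a node splits the tree into up to three disconnected components: its left subtree, its right subtree, and the rest of the tree above it. The score of a node is the product of the sizes of the non-empty components. Return how many nodes share the maximum score.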
# Example: parents = [-1, 2, 0, 2, 0]  ->  3
def count_highest_score_nodes(parents):
    """Count the nodes whose removal yields the highest score.

    A node's score is the product of the sizes of the non-empty components
    left after deleting it: each child subtree plus the part of the tree
    above it.

    Subtree sizes are computed with an explicit-stack DFS rather than
    recursion, so deep (e.g. chain-shaped) trees do not hit Python's
    recursion limit.

    Args:
        parents: parents[i] is the parent of node i; the root has -1.

    Returns:
        The number of nodes sharing the maximum score (0 for an empty list).

    Raises:
        ValueError: if a non-empty ``parents`` contains no root (-1).
    """
    n = len(parents)
    if n == 0:
        return 0

    children = [[] for _ in range(n)]
    for v, p in enumerate(parents):
        if p != -1:
            children[p].append(v)

    # Preorder traversal with an explicit stack.
    order = []
    stack = [parents.index(-1)]
    while stack:
        u = stack.pop()
        order.append(u)
        stack.extend(children[u])

    # Reverse preorder visits every child before its parent, so all child
    # subtree sizes are final by the time the parent is scored.
    size = [1] * n
    best, freq = 0, 0
    for u in reversed(order):
        score = 1
        for c in children[u]:
            size[u] += size[c]
            score *= size[c]
        rest = n - size[u]  # component above u; 0 only for the root
        if rest > 0:
            score *= rest
        if score > best:
            best, freq = score, 1
        elif score == best:
            freq += 1
    return freq
/**
 * Counts the nodes whose removal yields the highest score, where a node's
 * score is the product of the sizes of the non-empty components left after
 * deleting it (each child subtree plus the part of the tree above it).
 *
 * Subtree sizes come from an explicit-stack DFS instead of recursion, so very
 * deep (e.g. chain-shaped) trees cannot overflow the JavaScript call stack.
 * Scores are accumulated as BigInt so products are always exact.
 *
 * @param {number[]} parents parents[i] is the parent of node i; the root has -1.
 * @returns {number} how many nodes share the maximum score (0 for an empty array).
 */
function countHighestScoreNodes(parents) {
  const n = parents.length;
  if (n === 0) return 0;

  const children = Array.from({ length: n }, () => []);
  let root = 0;
  for (let v = 0; v < n; v++) {
    if (parents[v] === -1) root = v;
    else children[parents[v]].push(v);
  }

  // Preorder traversal with an explicit stack.
  const order = [];
  const stack = [root];
  while (stack.length > 0) {
    const u = stack.pop();
    order.push(u);
    for (const c of children[u]) stack.push(c);
  }

  // Reverse preorder handles every child before its parent, so child subtree
  // sizes are final when the parent is scored.
  const size = new Array(n).fill(1);
  let best = 0n, freq = 0;
  for (let i = order.length - 1; i >= 0; i--) {
    const u = order[i];
    let score = 1n;
    for (const c of children[u]) {
      size[u] += size[c];
      score *= BigInt(size[c]);
    }
    const rest = n - size[u]; // component above u; 0 only for the root
    if (rest > 0) score *= BigInt(rest);
    if (score > best) { best = score; freq = 1; }
    else if (score === best) freq++;
  }
  return freq;
}
class Solution {
    /**
     * Counts the nodes whose removal produces the highest score, where a node's score is the
     * product of the sizes of the non-empty components left after deleting it (each child
     * subtree plus the part of the tree above it).
     *
     * <p>All working state is local to the call, so the same {@code Solution} instance can be
     * reused for several inputs. Subtree sizes are computed with an explicit-stack DFS instead
     * of recursion, so very deep (e.g. chain-shaped) trees cannot overflow the call stack.
     *
     * @param parents {@code parents[i]} is the parent of node {@code i}; the root has {@code -1}
     * @return the number of nodes sharing the maximum score, or 0 for an empty array
     */
    public int countHighestScoreNodes(int[] parents) {
        final int n = parents.length;
        if (n == 0) {
            return 0;
        }

        // Children as singly linked sibling lists: firstChild[p] -> nextSibling[c] -> ... -> -1.
        int[] firstChild = new int[n];
        int[] nextSibling = new int[n];
        java.util.Arrays.fill(firstChild, -1);
        int root = 0;
        for (int v = 0; v < n; v++) {
            int p = parents[v];
            if (p == -1) {
                root = v;
            } else {
                nextSibling[v] = firstChild[p];
                firstChild[p] = v;
            }
        }

        // Preorder via an explicit stack; each node has one parent entry, so it is pushed at
        // most once and n slots suffice.
        int[] order = new int[n];
        int[] stack = new int[n];
        int visited = 0;
        int top = 0;
        stack[top++] = root;
        while (top > 0) {
            int u = stack[--top];
            order[visited++] = u;
            for (int c = firstChild[u]; c != -1; c = nextSibling[c]) {
                stack[top++] = c;
            }
        }

        // Reverse preorder visits every child before its parent, so child sizes are final
        // when the parent is scored.
        int[] subtreeSize = new int[n];
        long best = 0L;
        int freq = 0;
        for (int i = visited - 1; i >= 0; i--) {
            int u = order[i];
            int size = 1;
            long score = 1L;
            for (int c = firstChild[u]; c != -1; c = nextSibling[c]) {
                size += subtreeSize[c];
                score *= subtreeSize[c];
            }
            subtreeSize[u] = size;
            int rest = n - size; // component above u; 0 only for the root
            if (rest > 0) {
                score *= rest;
            }
            if (score > best) {
                best = score;
                freq = 1;
            } else if (score == best) {
                freq++;
            }
        }
        return freq;
    }
}
// Counts the nodes whose removal yields the highest score, where a node's score
// is the product of the sizes of the non-empty components left after deleting
// it (each child subtree plus the part of the tree above it).
//
// parents[i] is the parent of node i; the root has -1. Returns the number of
// nodes sharing the maximum score, or 0 for an empty input.
//
// Subtree sizes are computed with an explicit-stack DFS rather than recursion,
// so very deep (e.g. chain-shaped) trees cannot overflow the call stack.
int countHighestScoreNodes(vector<int>& parents) {
    const int n = static_cast<int>(parents.size());
    if (n == 0) return 0;

    vector<vector<int>> children(n);
    int root = 0;
    for (int v = 0; v < n; v++) {
        if (parents[v] == -1) root = v;
        else children[parents[v]].push_back(v);
    }

    // Preorder traversal with an explicit stack.
    vector<int> order;
    order.reserve(n);
    vector<int> pending{root};
    while (!pending.empty()) {
        int u = pending.back();
        pending.pop_back();
        order.push_back(u);
        for (int c : children[u]) pending.push_back(c);
    }

    // Reverse preorder handles every child before its parent, so child subtree
    // sizes are final when the parent is scored.
    vector<int> subtree(n, 1);
    long long best = 0;
    int freq = 0;
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        int u = *it;
        long long score = 1;
        for (int c : children[u]) {
            subtree[u] += subtree[c];
            score *= subtree[c];
        }
        int rest = n - subtree[u];  // component above u; 0 only for the root
        if (rest > 0) score *= rest;
        if (score > best) { best = score; freq = 1; }
        else if (score == best) freq++;
    }
    return freq;
}
Explanation
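If you imagine deleting one node, the tree falls apart into pieces: its left subtree, its right subtree, and everything else above it. The score is the product of the sizes of those pieces. Empty pieces are left out rather than multiplied in as 0; a piece is empty when a child is missing, or when the node is the root and nothing lies above it. The job is to find how many nodes share the highest score.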
The clever part is that we never actually delete anything. A single DFS computes the size of every subtree. Once you know a node's subtree size, the "rest of the tree" piece is simply n - size, where n is the total number of nodes.
For each node u, the DFS adds up the subtree sizes of u's children to get u's own subtree size, and at the same time multiplies those child sizes together into score. Once all of u's children have been accounted for, we fold in the upward piece with score *= rest when rest > 0. rest is 0 only at the root, where there is nothing above.
We keep a running best score and a count freq. If the new score beats best, we reset the count to 1; if it ties, we bump freq. At the end freq is the answer.
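Example: parents = [-1, 2, 0, 2, 0]. Node 0 is the root with children 2 and 4, and node 2 has children 1 and 3. Removing any of the leaves 1, 3 or 4 leaves a single component of size 4, so each of them scores 4. Node 2 scores 1 × 1 × 2 = 2, and node 0 scores 3 × 1 = 3. Three nodes share the maximum score of 4, so the answer is 3. Every node is visited once, so the algorithm runs in O(n) time with O(n) extra space.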