Sum of Distances in Tree
Problem
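Given an undirected connected tree with n nodes labeled 0 to n - 1 and an array edges of length n - 1, where edges[i] = [a, b] means there is an edge between nodes a and b, return an array ans of length n where ans[i] is the sum of the distances from node i to all other nodes.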
n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]][8,12,6,10,10,10]def sumOfDistancesInTree(n, edges):
g = [[] for _ in range(n)]
for a, b in edges:
g[a].append(b); g[b].append(a)
count = [1] * n
ans = [0] * n
def dfs1(u, p):
for v in g[u]:
if v == p: continue
dfs1(v, u)
count[u] += count[v]
ans[u] += ans[v] + count[v]
def dfs2(u, p):
for v in g[u]:
if v == p: continue
ans[v] = ans[u] - count[v] + (n - count[v])
dfs2(v, u)
dfs1(0, -1)
dfs2(0, -1)
return ans
/**
 * Returns ans where ans[i] is the sum of distances from node i to every other node.
 *
 * Both passes are iterative (explicit stack + precomputed visiting order), so deep
 * trees such as a long path do not overflow the JavaScript call stack.
 *
 * @param {number} n number of nodes, labeled 0..n-1
 * @param {number[][]} edges the n-1 undirected tree edges as [a, b] pairs
 * @returns {number[]} array of length n ([] when n <= 0)
 */
function sumOfDistancesInTree(n, edges) {
  if (n <= 0) return [];

  const g = Array.from({ length: n }, () => []);
  for (const [a, b] of edges) { g[a].push(b); g[b].push(a); }

  // Preorder walk from root 0: every node is recorded after its parent.
  const parent = new Array(n).fill(-1);
  const order = [];
  const stack = [0];
  while (stack.length > 0) {
    const u = stack.pop();
    order.push(u);
    for (const v of g[u]) {
      if (v !== parent[u]) {
        parent[v] = u;
        stack.push(v);
      }
    }
  }

  const count = new Array(n).fill(1); // subtree sizes, node itself included
  const ans = new Array(n).fill(0);   // pass 1: distance sums within each subtree

  // Pass 1 (bottom-up): reversed preorder finishes each child before its parent;
  // every node under u is one edge farther from parent[u].
  for (let i = order.length - 1; i > 0; i--) {
    const u = order[i];
    const p = parent[u];
    count[p] += count[u];
    ans[p] += ans[u] + count[u];
  }

  // Pass 2 (top-down reroot): moving the root from parent[u] to u brings count[u]
  // nodes one edge closer and pushes the remaining n - count[u] one edge farther.
  for (let i = 1; i < order.length; i++) {
    const u = order[i];
    ans[u] = ans[parent[u]] - count[u] + (n - count[u]);
  }
  return ans;
}
import java.util.*;
class Solution {
    // State shared by the two passes: subtree sizes, per-node answers,
    // adjacency lists, and the node count of the current call.
    int[] count, ans; List<List<Integer>> g; int N;

    /**
     * Computes, for every node of a tree, the sum of its distances to all other nodes.
     *
     * <p>The largest possible entry is n(n-1)/2 (an endpoint of a path), so results
     * overflow {@code int} once n exceeds 65536.
     *
     * @param n number of nodes, labeled 0..n-1
     * @param edges the n-1 undirected tree edges, each as {a, b}
     * @return ans where ans[i] is the sum of dist(i, j) over all j; empty when n == 0
     */
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        N = n;
        g = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            g.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            g.get(e[0]).add(e[1]);
            g.get(e[1]).add(e[0]);
        }
        count = new int[n];
        Arrays.fill(count, 1);
        ans = new int[n];
        if (n > 0) { // both passes start at node 0, which only exists when n >= 1
            dfs1(0, -1);
            dfs2(0, -1);
        }
        return ans;
    }

    /**
     * Bottom-up pass over the subtree reached from {@code u} without going through
     * {@code p}: adds each child's subtree size into {@code count} and its
     * distance sum (shifted by one edge per node) into {@code ans}.
     * Iterative, so deep trees cannot overflow the call stack.
     */
    void dfs1(int u, int p) {
        int[] order = new int[N];
        int[] parent = new int[N];
        int size = preorder(u, p, order, parent);
        // Reverse preorder finishes every child before its parent.
        for (int i = size - 1; i > 0; i--) {
            int v = order[i];
            int q = parent[v];
            count[q] += count[v];
            ans[q] += ans[v] + count[v];
        }
    }

    /**
     * Top-down rerooting pass over the subtree reached from {@code u} without going
     * through {@code p}: derives each node's full answer from its parent's.
     * Iterative, so deep trees cannot overflow the call stack.
     */
    void dfs2(int u, int p) {
        int[] order = new int[N];
        int[] parent = new int[N];
        int size = preorder(u, p, order, parent);
        // Preorder guarantees ans[parent[v]] is final before v is processed.
        for (int i = 1; i < size; i++) {
            int v = order[i];
            // Moving the root from parent[v] to v: count[v] nodes get one edge closer,
            // the other N - count[v] nodes get one edge farther.
            ans[v] = ans[parent[v]] - count[v] + (N - count[v]);
        }
    }

    /**
     * Iterative preorder walk of the nodes reachable from {@code root} without passing
     * through {@code rootParent}. Writes visited nodes (parents before children) into
     * {@code order}, each visited node's parent into {@code parent}, and returns the
     * number of nodes visited.
     */
    private int preorder(int root, int rootParent, int[] order, int[] parent) {
        int[] stack = new int[N]; // in a tree each node is pushed at most once
        int top = 0;
        int size = 0;
        parent[root] = rootParent;
        stack[top++] = root;
        while (top > 0) {
            int x = stack[--top];
            order[size++] = x;
            for (int y : g.get(x)) {
                if (y != parent[x]) {
                    parent[y] = x;
                    stack[top++] = y;
                }
            }
        }
        return size;
    }
}
#include <bits/stdc++.h>
using namespace std;
class Solution {
public:
    // Adjacency lists, subtree sizes, per-node answers and the node count.
    vector<vector<int>> g; vector<int> count, ans; int N;

    // Bottom-up pass: after returning, count[node] is the size of node's subtree
    // (rooted away from `from`) and ans[node] its distance sum inside that subtree.
    void dfs1(int node, int from) {
        for (size_t i = 0; i < g[node].size(); ++i) {
            const int child = g[node][i];
            if (child == from) {
                continue;
            }
            dfs1(child, node);
            count[node] += count[child];
            // Each of the count[child] nodes below child is one edge farther from node.
            ans[node] += ans[child] + count[child];
        }
    }

    // Top-down rerooting pass: turns subtree sums into full answers, parent first.
    void dfs2(int node, int from) {
        for (size_t i = 0; i < g[node].size(); ++i) {
            const int child = g[node][i];
            if (child == from) {
                continue;
            }
            const int inside = count[child];
            // Nodes on child's side get one edge closer; all others one edge farther.
            ans[child] = ans[node] - inside + (N - inside);
            dfs2(child, node);
        }
    }

    // Returns, for each node 0..n-1, the sum of its distances to all other nodes.
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
        N = n;
        g.assign(n, {});
        for (const auto& edge : edges) {
            const int a = edge[0];
            const int b = edge[1];
            g[a].push_back(b);
            g[b].push_back(a);
        }
        count.assign(n, 1);
        ans.assign(n, 0);
        dfs1(0, -1);
        dfs2(0, -1);
        return ans;
    }
};
Explanation
Computing the distance sum for each node separately (one BFS or DFS per node) takes O(n^2) time. Instead, we root the tree at node 0 and use a two-pass DFS, a technique called rerooting, so that each node's answer is derived from its parent's answer in constant time.
The first pass, dfs1, works bottom-up. For each node u it records count[u], the number of nodes in u's subtree (including u itself), and ans[u], the sum of distances from u to just the nodes in its subtree. A child v contributes ans[v] + count[v], because each of the count[v] nodes under v is exactly one edge farther from u than it is from v.
The second pass, dfs2, works top-down and reroots the answer. When we move the "root" from u to its child v, the count[v] nodes on v's side each get one closer, and the remaining n - count[v] nodes each get one farther. So ans[v] = ans[u] - count[v] + (n - count[v]).
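Example: with n = 6 and the edges above, node 0 is joined to nodes 1 and 2, and node 2 is the hub for nodes 3, 4 and 5. dfs1 computes the root's total, ans[0] = 8. dfs2 then slides that total outward to every other node, producing [8,12,6,10,10,10]. For instance, count[2] = 4, so moving the root from 0 to 2 gives ans[2] = 8 - 4 + (6 - 4) = 6.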
Each pass visits every node once and scans every edge twice (once from each endpoint), so the whole algorithm runs in O(n) time instead of O(n^2).