Sum of Distances in Tree

hard tree dfs rerooting

Problem

Given an undirected, connected tree with n nodes labeled 0 to n - 1, described by a list edges of its n - 1 edges, return an array ans of length n where ans[i] is the sum of the distances from node i to all other nodes.

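Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: each value is the sum of shortest distances from a node to every other node. For example, node 0 is at distance 1 from nodes 1 and 2 and at distance 2 from nodes 3, 4 and 5, so ans[0] = 1 + 1 + 2 + 2 + 2 = 8.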

def sumOfDistancesInTree(n, edges):
    """Return ans where ans[i] is the sum of distances from node i to all other nodes.

    Rerooting in two linear passes over a traversal rooted at node 0:

    1. Post-order: count[u] becomes the size of u's subtree and ans[u] the sum
       of distances from u to the nodes in its subtree. Afterwards ans[0] is
       exact.
    2. Pre-order: moving the root from parent p to child u brings count[u]
       nodes one step closer and pushes the other n - count[u] nodes one step
       farther, so ans[u] = ans[p] - count[u] + (n - count[u]).

    Both passes are iterative, so a path-shaped tree cannot exceed Python's
    recursion limit.

    Args:
        n: number of nodes, labelled 0..n-1.
        edges: n - 1 pairs [a, b] forming a connected tree.

    Returns:
        A list of n ints; an empty list when n <= 0.
    """
    if n <= 0:
        return []
    g = [[] for _ in range(n)]
    for a, b in edges:
        g[a].append(b)
        g[b].append(a)

    # Iterative pre-order from root 0: every node is appended after its parent.
    # Skipping the parent is enough to avoid revisits because a tree has no cycles.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in g[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    count = [1] * n  # subtree sizes, each node counting itself
    ans = [0] * n
    # Reverse pre-order is a valid post-order: a node's totals are final
    # before its parent absorbs them.
    for u in reversed(order):
        p = parent[u]
        if p != -1:
            count[p] += count[u]
            ans[p] += ans[u] + count[u]
    # Forward pre-order: the parent's answer is final before a child uses it.
    # order[0] is the root, which already holds its exact answer.
    for u in order[1:]:
        p = parent[u]
        ans[u] = ans[p] - count[u] + (n - count[u])
    return ans
/**
 * Returns ans where ans[i] is the sum of distances from node i to all other nodes.
 *
 * Rerooting in two linear passes over a traversal rooted at node 0.
 * A post-order pass computes subtree sizes and the exact answer for the root.
 * A pre-order pass then derives each child's answer from its parent's.
 * Moving the root from p to child u brings count[u] nodes one step closer and
 * pushes the other n - count[u] nodes one step farther.
 * Both passes are iterative, so a path-shaped tree cannot overflow the call stack.
 *
 * @param {number} n number of nodes, labelled 0..n-1
 * @param {number[][]} edges n - 1 pairs [a, b] forming a connected tree
 * @returns {number[]} per-node distance sums; an empty array when n <= 0
 */
function sumOfDistancesInTree(n, edges) {
  if (n <= 0) return [];
  const g = Array.from({ length: n }, () => []);
  for (const [a, b] of edges) {
    g[a].push(b);
    g[b].push(a);
  }

  // Iterative pre-order from root 0: every node is recorded after its parent.
  // Skipping the parent is enough to avoid revisits because a tree has no cycles.
  const parent = new Array(n).fill(-1);
  const order = [];
  const todo = [0];
  while (todo.length > 0) {
    const u = todo.pop();
    order.push(u);
    for (const v of g[u]) {
      if (v !== parent[u]) {
        parent[v] = u;
        todo.push(v);
      }
    }
  }

  const count = new Array(n).fill(1); // subtree sizes, each node counting itself
  const ans = new Array(n).fill(0);
  // Reverse pre-order is a valid post-order: children are final before their parent absorbs them.
  for (let i = order.length - 1; i >= 0; i--) {
    const u = order[i];
    const p = parent[u];
    if (p === -1) continue;
    count[p] += count[u];
    ans[p] += ans[u] + count[u];
  }
  // Forward pre-order: the parent's answer is final before a child uses it.
  for (let i = 1; i < order.length; i++) {
    const u = order[i];
    const p = parent[u];
    ans[u] = ans[p] - count[u] + (n - count[u]);
  }
  return ans;
}
import java.util.*;
class Solution {
  /**
   * Returns {@code ans} where {@code ans[i]} is the sum of the distances from node {@code i} to
   * every other node of the tree.
   *
   * <p>Rerooting in two linear passes over a traversal rooted at node 0. A post-order pass
   * computes subtree sizes and the exact answer for the root. A pre-order pass then derives each
   * child's answer from its parent's. Moving the root from {@code p} to child {@code u} brings
   * {@code count[u]} nodes one step closer and pushes the other {@code n - count[u]} nodes one
   * step farther. Both passes are iterative, so a path-shaped tree cannot overflow the call
   * stack.
   *
   * <p>Results are {@code int}. The largest possible answer is {@code n(n-1)/2}, at an end of a
   * path, and it fits in an {@code int} for {@code n <= 65536}.
   *
   * @param n number of nodes, labelled {@code 0..n-1}
   * @param edges {@code n - 1} undirected edges {@code [a, b]} forming a connected tree
   * @return per-node distance sums; an empty array when {@code n <= 0}
   */
  public int[] sumOfDistancesInTree(int n, int[][] edges) {
    if (n <= 0) {
      return new int[0];
    }
    List<List<Integer>> g = new ArrayList<>(n);
    for (int i = 0; i < n; i++) {
      g.add(new ArrayList<>());
    }
    for (int[] e : edges) {
      g.get(e[0]).add(e[1]);
      g.get(e[1]).add(e[0]);
    }

    // Iterative pre-order from root 0: every node is recorded after its parent. In a tree each
    // node is pushed exactly once, so n slots suffice for both the stack and the order.
    int[] parent = new int[n];
    int[] order = new int[n];
    int[] stack = new int[n];
    int visited = 0;
    int top = 0;
    parent[0] = -1;
    stack[top++] = 0;
    while (top > 0) {
      int u = stack[--top];
      order[visited++] = u;
      for (int v : g.get(u)) {
        // Skipping the parent is enough to avoid revisits because a tree has no cycles.
        if (v != parent[u]) {
          parent[v] = u;
          stack[top++] = v;
        }
      }
    }

    int[] count = new int[n]; // subtree sizes, each node counting itself
    Arrays.fill(count, 1);
    int[] ans = new int[n];
    // Reverse pre-order is a valid post-order: children are final before their parent absorbs
    // them.
    for (int i = visited - 1; i >= 0; i--) {
      int u = order[i];
      int p = parent[u];
      if (p >= 0) {
        count[p] += count[u];
        ans[p] += ans[u] + count[u];
      }
    }
    // Forward pre-order: the parent's answer is final before a child uses it. order[0] is the
    // root, which already holds its exact answer.
    for (int i = 1; i < visited; i++) {
      int u = order[i];
      int p = parent[u];
      ans[u] = ans[p] - count[u] + (n - count[u]);
    }
    return ans;
  }
}
#include <bits/stdc++.h>
using namespace std;
class Solution {
public:
    // Returns ans where ans[i] is the sum of distances from node i to all other nodes.
    //
    // Rerooting in two linear passes over a traversal rooted at node 0. A post-order pass
    // computes subtree sizes and the exact answer for the root. A pre-order pass then derives
    // each child's answer from its parent's. Moving the root from p to child u brings
    // subtree[u] nodes one step closer and pushes the other n - subtree[u] nodes one step
    // farther. Both passes are iterative, so a path-shaped tree cannot overflow the call stack.
    //
    // n: number of nodes, labelled 0..n-1. edges: n - 1 pairs {a, b} forming a connected tree.
    // Returns an empty vector when n <= 0.
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
        if (n <= 0) return {};
        vector<vector<int>> g(n);
        for (const auto& e : edges) {
            g[e[0]].push_back(e[1]);
            g[e[1]].push_back(e[0]);
        }

        // Iterative pre-order from root 0: every node is recorded after its parent.
        // Skipping the parent is enough to avoid revisits because a tree has no cycles.
        vector<int> parent(n, -1);
        vector<int> order;
        order.reserve(n);
        vector<int> todo{0};
        while (!todo.empty()) {
            int u = todo.back();
            todo.pop_back();
            order.push_back(u);
            for (int v : g[u]) {
                if (v != parent[u]) {
                    parent[v] = u;
                    todo.push_back(v);
                }
            }
        }

        vector<int> subtree(n, 1);  // subtree sizes, each node counting itself
        vector<int> ans(n, 0);
        // Reverse pre-order is a valid post-order: children are final before their parent
        // absorbs them.
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            int u = *it;
            int p = parent[u];
            if (p < 0) continue;
            subtree[p] += subtree[u];
            ans[p] += ans[u] + subtree[u];
        }
        // Forward pre-order: the parent's answer is final before a child uses it.
        // order[0] is the root, which already holds its exact answer.
        for (size_t i = 1; i < order.size(); ++i) {
            int u = order[i];
            int p = parent[u];
            ans[u] = ans[p] - subtree[u] + (n - subtree[u]);
        }
        return ans;
    }
};
Time: O(n), Space: O(n)