Tree Diameter

Difficulty: medium · Topics: tree, BFS, graph

Problem

Given an undirected tree with n nodes labeled 0 to n − 1, described by a list of edges, return the diameter of the tree: the number of edges in the longest path between any two nodes.

Input: edges = [[0,1],[1,2],[1,3],[3,4]]
Output: 3
The longest path runs 0 → 1 → 3 → 4 (or 2 → 1 → 3 → 4), covering 3 edges.

from collections import deque

def tree_diameter(edges):
    """Return the diameter of a tree: the number of edges on its longest path.

    Uses two breadth-first searches. The node farthest from an arbitrary
    start is an endpoint of some longest path. The largest distance from
    that endpoint is therefore the diameter.

    Args:
        edges: ``n - 1`` pairs ``[a, b]`` describing an undirected tree whose
            nodes are labeled ``0 .. n - 1``. An empty list describes a
            single-node tree.

    Returns:
        The diameter as an int (0 for a single node).
    """
    n = len(edges) + 1
    g = [[] for _ in range(n)]
    for a, b in edges:
        g[a].append(b)
        g[b].append(a)

    def bfs(src):
        """Return ``(farthest_node, distance)`` from ``src``.

        On ties, the first farthest node discovered is kept.
        """
        dist = [-1] * n
        dist[src] = 0
        q = deque([src])
        far = src
        while q:
            u = q.popleft()
            for v in g[u]:
                if dist[v] == -1:
                    dist[v] = dist[u] + 1
                    if dist[v] > dist[far]:
                        far = v
                    q.append(v)
        return far, dist[far]

    endpoint, _ = bfs(0)          # farthest node from an arbitrary start
    _, diameter = bfs(endpoint)   # farthest distance from that endpoint
    return diameter
/**
 * Length, in edges, of the longest path in an undirected tree.
 *
 * Two breadth-first sweeps are used. The first finds the node farthest from
 * node 0. The second measures the farthest distance from that node, which
 * is the diameter.
 *
 * @param {number[][]} edges n - 1 pairs [u, v] over nodes 0 .. n - 1.
 * @returns {number} The diameter (0 when edges is empty).
 */
function treeDiameter(edges) {
  const nodeCount = edges.length + 1;
  const adjacency = [];
  for (let i = 0; i < nodeCount; i++) adjacency.push([]);
  edges.forEach(([u, v]) => {
    adjacency[u].push(v);
    adjacency[v].push(u);
  });

  // Returns [farthest node from start, its distance].
  // On ties, the first node discovered wins.
  const sweep = (start) => {
    const depth = Array(nodeCount).fill(-1);
    depth[start] = 0;
    const order = [start];
    let best = start;
    let head = 0;
    while (head < order.length) {
      const node = order[head++];
      for (const next of adjacency[node]) {
        if (depth[next] !== -1) continue;
        depth[next] = depth[node] + 1;
        if (depth[next] > depth[best]) best = next;
        order.push(next);
      }
    }
    return [best, depth[best]];
  };

  const endpoint = sweep(0)[0];
  return sweep(endpoint)[1];
}
class Solution {
    int n;
    List<List<Integer>> g;

    /**
     * Returns the number of edges on the longest path of an undirected tree.
     *
     * The first sweep locates the node farthest from node 0. The second
     * sweep measures the farthest distance from that node.
     *
     * @param edges n - 1 pairs {u, v} over nodes 0 .. n - 1
     * @return the diameter (0 when edges is empty)
     */
    public int treeDiameter(int[][] edges) {
        n = edges.length + 1;
        g = new ArrayList<>(n);
        while (g.size() < n) g.add(new ArrayList<>());
        for (int[] edge : edges) {
            int u = edge[0], v = edge[1];
            g.get(u).add(v);
            g.get(v).add(u);
        }
        int[] first = bfs(0);
        int[] second = bfs(first[0]);
        return second[1];
    }

    /**
     * Breadth-first search from src over the adjacency lists in g.
     *
     * @return {farthest node, its distance}; on ties the first node discovered wins
     */
    int[] bfs(int src) {
        int[] depth = new int[n];
        Arrays.fill(depth, -1);
        depth[src] = 0;
        // Every node is enqueued at most once, so n slots suffice.
        int[] queue = new int[n];
        int head = 0, tail = 0;
        queue[tail++] = src;
        int far = src;
        while (head < tail) {
            int cur = queue[head++];
            for (int nxt : g.get(cur)) {
                if (depth[nxt] != -1) continue;
                depth[nxt] = depth[cur] + 1;
                if (depth[nxt] > depth[far]) far = nxt;
                queue[tail++] = nxt;
            }
        }
        return new int[]{ far, depth[far] };
    }
}
class Solution {
public:
    /// @brief Number of edges on the longest path in an undirected tree.
    ///
    /// Two breadth-first sweeps are used. The vertex farthest from an
    /// arbitrary start is an endpoint of some longest path, so the largest
    /// distance from that vertex is the diameter.
    /// @param edges n - 1 pairs {a, b} over vertices 0 .. n - 1 (empty means a single vertex).
    /// @return The diameter; 0 for a single vertex.
    int treeDiameter(vector<vector<int>>& edges) {
        // Explicit cast: edges.size() is size_t, and the vertex count is used as int.
        const int n = static_cast<int>(edges.size()) + 1;
        vector<vector<int>> g(n);
        for (const auto& e : edges) {
            g[e[0]].push_back(e[1]);
            g[e[1]].push_back(e[0]);
        }

        // Returns {farthest vertex from src, its distance}.
        // On ties, the first vertex discovered is kept.
        auto bfs = [&](int src) {
            vector<int> dist(n, -1);
            dist[src] = 0;
            // Each vertex is enqueued at most once, so a vector reserved to n
            // with a read cursor serves as the FIFO and allocates only once.
            vector<int> order;
            order.reserve(n);
            order.push_back(src);
            int far = src;
            for (size_t head = 0; head < order.size(); ++head) {
                const int u = order[head];
                for (const int v : g[u]) {
                    if (dist[v] != -1) continue;
                    dist[v] = dist[u] + 1;
                    if (dist[v] > dist[far]) far = v;
                    order.push_back(v);
                }
            }
            return pair<int, int>{far, dist[far]};
        };

        const int endpoint = bfs(0).first;   // farthest from an arbitrary start
        return bfs(endpoint).second;         // farthest from that endpoint = diameter
    }
};
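Time: O(n) · Space: O(n). Each of the two BFS passes visits every node and every edge once, and stores one distance per node.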