Find Leaves of Binary Tree
Problem
Repeatedly remove all current leaves of a binary tree until the tree is empty, and return the removed values grouped by round, in the order the rounds happen.
Input: root = [1,2,3,4,5]
Output: [[4,5,3],[2],[1]]

```python
def find_leaves(root):
    out = []

    def dfs(node):
        if not node:
            return -1
        d = max(dfs(node.left), dfs(node.right)) + 1
        if d == len(out):
            out.append([])
        out[d].append(node.val)
        return d

    dfs(root)
    return out
```
```javascript
function findLeaves(root) {
  const out = [];
  function dfs(node) {
    if (!node) return -1;
    const d = Math.max(dfs(node.left), dfs(node.right)) + 1;
    if (d === out.length) out.push([]);
    out[d].push(node.val);
    return d;
  }
  dfs(root);
  return out;
}
```
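```java
import java.util.ArrayList;
import java.util.List;

// TreeNode (val, left, right) is assumed to be defined elsewhere.
class Solution {
    List<List<Integer>> out = new ArrayList<>();

    public List<List<Integer>> findLeaves(TreeNode root) { dfs(root); return out; }

    int dfs(TreeNode node) {
        if (node == null) return -1;
        int d = Math.max(dfs(node.left), dfs(node.right)) + 1;
        if (d == out.size()) out.add(new ArrayList<>());
        out.get(d).add(node.val);
        return d;
    }
}
```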
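```cpp
#include <algorithm>
#include <vector>
using namespace std;

// TreeNode (val, left, right) is assumed to be defined before this class.
class Solution {
    vector<vector<int>> out;
    int dfs(TreeNode* n) {
        if (!n) return -1;
        int d = max(dfs(n->left), dfs(n->right)) + 1;
        if (d == (int)out.size()) out.emplace_back();
        out[d].push_back(n->val);
        return d;
    }
public:
    vector<vector<int>> findLeaves(TreeNode* root) { dfs(root); return out; }
};
```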
Explanation
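You could simulate the process literally, peeling off the current leaves over and over, but every pass rescans the remaining tree. For a tree with n nodes and height h that is h + 1 passes of up to O(n) work each, which degrades to O(n²) on a degenerate, list-like tree. The better approach is to work out, in a single traversal, which removal round each node belongs to.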
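For comparison, the literal peel-off simulation described above can be sketched in Python as follows. This is a minimal sketch, assuming a plain `TreeNode` class with `val`, `left`, and `right`; the names `find_leaves_naive` and `strip` are hypothetical. It mutates the input tree, and each iteration of the `while` loop is one full pass over whatever remains.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def find_leaves_naive(root: Optional[TreeNode]) -> List[List[int]]:
    """Hypothetical reference version: strip all current leaves, then repeat."""
    out: List[List[int]] = []

    def strip(node: Optional[TreeNode], group: List[int]) -> Optional[TreeNode]:
        # Returns the subtree with its current leaves removed and
        # appends each removed leaf's value to `group`.
        if node is None:
            return None
        if node.left is None and node.right is None:
            group.append(node.val)
            return None
        node.left = strip(node.left, group)
        node.right = strip(node.right, group)
        return node

    while root is not None:
        group: List[int] = []
        root = strip(root, group)  # one full pass over the remaining tree
        out.append(group)
    return out
```

On the example tree it yields the same grouping as the single-pass solution, so it is mainly useful as a slow reference to test the faster version against.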
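The key insight is that a node's removal round (counting from 0) equals its height: the number of edges on the longest path from it down to a leaf. A node can only become a leaf once all of its children are gone, so it is removed exactly one round after its last-removed child; that gives round(node) = max(round(left), round(right)) + 1, the same recurrence that defines height. A true leaf has height 0 and is removed first, a node whose children are all leaves has height 1, and in general a node's height is set by its tallest child, not its nearest one. We compute this with a post-order DFS: visit both children first, then the node.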
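To see the height rule on a lopsided tree, the following sketch maps each node's value to its height. `node_heights` is a hypothetical helper, and it assumes node values are unique because they serve as dictionary keys.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def node_heights(root: Optional[TreeNode]) -> Dict[int, int]:
    """Map each node's value to its height (assumes unique values)."""
    heights: Dict[int, int] = {}

    def height(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1  # an empty subtree sits one level below a leaf
        # Post-order: both children are measured before this node.
        h = max(height(node.left), height(node.right)) + 1
        heights[node.val] = h
        return h

    height(root)
    return heights
```

For a tree where 1 has children 2 and 5, 2's left child is 3, and 3's left child is 4, it returns {4: 0, 3: 1, 2: 2, 5: 0, 1: 3}. Node 1 is the parent of leaf 5, yet it is removed last, in round 3, because its height follows its tallest child.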
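For each node, d = max(dfs(left), dfs(right)) + 1. A missing child returns -1, so a true leaf gets max(-1, -1) + 1 = 0. We then append the node's value to bucket out[d], first appending a new empty list to out whenever d reaches a height not seen before (d == len(out)). That single check is enough: a node of height d > 0 has a child of height d - 1 that post-order has already filed, so d never exceeds len(out) and no bucket is ever skipped.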
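The bucket step can be made visible by instrumenting the same recurrence to log, in visiting order, each node's value, its computed d, and whether it opened a new bucket. `trace_buckets` is a hypothetical name used only for illustration; the `TreeNode` class is the same minimal assumption as before.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def trace_buckets(root: Optional[TreeNode]) -> List[Tuple[int, int, bool]]:
    """Record (value, d, opened_new_bucket) for each node in post-order."""
    out: List[List[int]] = []
    events: List[Tuple[int, int, bool]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1  # missing child, so a leaf gets max(-1, -1) + 1 == 0
        d = max(dfs(node.left), dfs(node.right)) + 1
        opened = d == len(out)  # first node ever seen at this height
        if opened:
            out.append([])
        out[d].append(node.val)
        events.append((node.val, d, opened))
        return d

    dfs(root)
    return events
```

On root = [1,2,3,4,5] it returns [(4, 0, True), (5, 0, False), (2, 1, True), (3, 0, False), (1, 2, True)]: a bucket opens exactly when d first reaches a new height.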
Example: root = [1,2,3,4,5], so 1 has children 2 and 3, and 2 has children 4 and 5. Nodes 4, 5, and 3 are leaves, so each gets height 0. Both of node 2's children have height 0, so it gets max(0, 0) + 1 = 1; the root gets max(1, 0) + 1 = 2. Post-order visits 4, 5, 2, 3, 1, so the buckets fill as [[4,5,3],[2],[1]].
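To run this worked example end to end, the level-order list first has to be turned into a tree. The `build_tree` helper below is a hypothetical sketch that assumes LeetCode-style level-order input, where None marks a missing child; `find_leaves` is the Python solution from above, repeated so the block runs on its own.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: level-order list (None = missing child) -> tree."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two list entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def find_leaves(root):
    out = []

    def dfs(node):
        if not node:
            return -1
        d = max(dfs(node.left), dfs(node.right)) + 1
        if d == len(out):
            out.append([])
        out[d].append(node.val)
        return d

    dfs(root)
    return out


print(find_leaves(build_tree([1, 2, 3, 4, 5])))
```

Here `find_leaves(build_tree([1, 2, 3, 4, 5]))` evaluates to [[4, 5, 3], [2], [1]], matching the hand trace.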
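Each node is visited exactly once and its bucket index comes directly from its children's return values, so the whole answer is built in a single O(n) pass. Beyond the O(n) output, the extra space is the O(h) recursion stack, which is O(n) on a degenerate tree.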
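Because CPython's default recursion limit is 1000, the recursive Python version can raise RecursionError on a very deep, list-like tree. One way around that, shown here as a sketch rather than as the reference solution, is an iterative post-order traversal with an explicit stack that keeps the same bucket logic. `find_leaves_iterative` is a hypothetical name, and heights are stored in a dict keyed by node identity, which assumes `TreeNode` does not override `__eq__` or `__hash__`.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    # Plain class (not a dataclass): instances hash by identity, as the dict below needs.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> List[List[int]]:
    out: List[List[int]] = []
    height: Dict[TreeNode, int] = {}  # node -> height, keyed by object identity
    stack: List[Tuple[TreeNode, bool]] = [(root, False)] if root else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node only after both subtrees are finished.
            stack.append((node, True))
            if node.right:
                stack.append((node.right, False))
            if node.left:  # pushed last, so the left subtree is handled first
                stack.append((node.left, False))
            continue
        left_h = height[node.left] if node.left else -1
        right_h = height[node.right] if node.right else -1
        d = max(left_h, right_h) + 1
        height[node] = d
        if d == len(out):  # first node at this height: open a new bucket
            out.append([])
        out[d].append(node.val)
    return out
```

Children are pushed right-then-left so the left subtree finishes first, which reproduces the recursive left, right, node order and therefore the same order inside each bucket.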