Number of Good Leaf Nodes Pairs

medium binary tree dfs

Problem

You are given the root of a binary tree and an integer distance. Two different leaf nodes form a good pair if the shortest path between them uses at most distance edges. Count how many good leaf pairs the tree contains.

The tree has at most 2^10 (1024) nodes and 1 ≤ distance ≤ 10, so a per-node array of leaf counts by depth stays tiny.

Input: root = [1,2,3,4,5,6,7], distance = 3
Output: 2
Leaf pairs (4, 5) and (6, 7) are each 2 edges apart; every pair crossing the root, like (4, 6), needs 4 edges — over the budget.

def count_pairs(root, distance):
    """Count pairs of distinct leaves joined by a path of at most `distance` edges.

    Every subtree reports a histogram whose index d holds the number of leaves
    exactly d edges below that subtree's root. At each internal node the left
    and right histograms are matched against each other: a leaf i edges under
    the left child and a leaf j edges under the right child are i + j + 2
    edges apart.

    Args:
        root: tree node exposing `left` and `right` attributes, or a falsy
            value for an empty tree.
        distance: maximum number of edges allowed between the two leaves.

    Returns:
        The number of good leaf pairs.
    """
    good_pairs = 0

    def leaf_depths(node):
        nonlocal good_pairs
        if not node:
            return []
        if not (node.left or node.right):
            return [1]  # the leaf itself, zero edges away
        from_left = leaf_depths(node.left)
        from_right = leaf_depths(node.right)
        # Pairs whose path bends at this node: one leaf from each side.
        good_pairs += sum(
            n_left * n_right
            for i, n_left in enumerate(from_left)
            for j, n_right in enumerate(from_right)
            if i + j + 2 <= distance
        )
        # Re-index every leaf one edge farther away. Leaves that would land at
        # depth >= distance are dropped: they can no longer be part of a pair.
        lifted = [0] * distance
        for side in (from_left, from_right):
            for depth, count in zip(range(1, distance), side):
                lifted[depth] += count
        return lifted

    leaf_depths(root)
    return good_pairs
/**
 * Counts pairs of distinct leaves joined by a path of at most `distance` edges.
 *
 * Each subtree reports a histogram: index d holds how many leaves sit exactly
 * d edges below that subtree's root. A leaf i edges under the left child and a
 * leaf j edges under the right child are i + j + 2 edges apart.
 *
 * @param {?{left: ?Object, right: ?Object}} root - tree root, or null/undefined.
 * @param {number} distance - maximum number of edges between the two leaves.
 * @returns {number} number of good leaf pairs.
 */
function countPairs(root, distance) {
  let goodPairs = 0;

  const leafDepths = (node) => {
    if (!node) return [];
    if (!(node.left || node.right)) return [1]; // the leaf itself, zero edges away

    const fromLeft = leafDepths(node.left);
    const fromRight = leafDepths(node.right);

    // Pairs whose path bends at this node: one leaf from each side.
    fromLeft.forEach((nLeft, i) => {
      fromRight.forEach((nRight, j) => {
        if (i + j + 2 <= distance) goodPairs += nLeft * nRight;
      });
    });

    // Re-index every leaf one edge farther away; anything that would land at
    // depth >= distance is dropped because it can no longer form a pair.
    const lifted = new Array(distance).fill(0);
    for (const side of [fromLeft, fromRight]) {
      const limit = Math.min(side.length, distance - 1);
      for (let d = 0; d < limit; d++) lifted[d + 1] += side[d];
    }
    return lifted;
  };

  leafDepths(root);
  return goodPairs;
}
/**
 * Counts pairs of distinct leaves whose connecting path uses at most
 * {@code distance} edges.
 *
 * Each subtree reports an array in which index d holds the number of leaves
 * exactly d edges below that subtree's root; pairs are counted at the node
 * where the path between the two leaves bends.
 */
class Solution {
    /** Good pairs found by the traversal in progress; reset on every countPairs call. */
    int total = 0;

    /**
     * @param root     root of the tree, or null for an empty tree
     * @param distance maximum number of edges allowed between the two leaves
     * @return number of good leaf pairs
     */
    public int countPairs(TreeNode root, int distance) {
        // Reset so that reusing one Solution instance does not carry over the
        // count from a previous tree.
        total = 0;
        // Two distinct leaves are always at least 2 edges apart, so nothing can
        // qualify; this also keeps a negative distance away from new int[distance].
        if (distance < 2) return 0;
        dfs(root, distance);
        return total;
    }

    /**
     * Adds to {@code total} every good pair whose path bends at {@code node} or
     * below it.
     *
     * @return leaf counts indexed by edge depth below {@code node}; empty for a
     *         null node, {1} for a leaf, otherwise an array of length
     *         {@code distance}
     */
    int[] dfs(TreeNode node, int distance) {
        if (node == null) return new int[0];
        if (node.left == null && node.right == null)
            return new int[]{1};               // one leaf at depth 0
        int[] left = dfs(node.left, distance);
        int[] right = dfs(node.right, distance);
        // A leaf i edges under the left child and one j edges under the right
        // child are i + j + 2 edges apart.
        for (int i = 0; i < left.length; i++)
            for (int j = 0; j < right.length; j++)
                if (i + j + 2 <= distance) total += left[i] * right[j];
        // Shift depths up one edge; leaves that would reach depth >= distance
        // are dropped because they can no longer pair with anything.
        int[] up = new int[distance];
        for (int i = 0; i < Math.min(left.length, distance - 1); i++)
            up[i + 1] += left[i];
        for (int i = 0; i < Math.min(right.length, distance - 1); i++)
            up[i + 1] += right[i];
        return up;
    }
}
/// Counts pairs of distinct leaves whose connecting path uses at most
/// `distance` edges.
///
/// Each subtree reports a vector in which index d holds the number of leaves
/// exactly d edges below that subtree's root; pairs are counted at the node
/// where the path between the two leaves bends.
class Solution {
public:
    /// Good pairs found by the traversal in progress; reset on every countPairs call.
    int total = 0;

    /// @param root     root of the tree, or nullptr for an empty tree
    /// @param distance maximum number of edges allowed between the two leaves
    /// @return number of good leaf pairs
    int countPairs(TreeNode* root, int distance) {
        // Reset so that reusing one Solution object does not carry over the
        // count from a previous tree.
        total = 0;
        // Two distinct leaves are always at least 2 edges apart, so nothing can
        // qualify; this also keeps a negative distance from being converted to
        // a huge vector size inside dfs.
        if (distance < 2) return 0;
        dfs(root, distance);
        return total;
    }

    /// Adds to `total` every good pair whose path bends at `node` or below it.
    /// @return leaf counts indexed by edge depth below `node`: empty for a null
    ///         node, {1} for a leaf, otherwise a vector of length `distance`.
    vector<int> dfs(TreeNode* node, int distance) {
        if (!node) return {};
        if (!node->left && !node->right) return {1}; // leaf at depth 0
        const auto left = dfs(node->left, distance);
        const auto right = dfs(node->right, distance);
        // A leaf i edges under the left child and one j edges under the right
        // child are i + j + 2 edges apart.
        for (int i = 0; i < (int)left.size(); i++)
            for (int j = 0; j < (int)right.size(); j++)
                if (i + j + 2 <= distance) total += left[i] * right[j];
        // Shift depths up one edge; leaves that would reach depth >= distance
        // are dropped because they can no longer pair with anything.
        vector<int> up(distance, 0);
        for (int i = 0; i + 1 < distance && i < (int)left.size(); i++)
            up[i + 1] += left[i];
        for (int i = 0; i + 1 < distance && i < (int)right.size(); i++)
            up[i + 1] += right[i];
        return up;
    }
};
Time: O(n · distance²) Space: O(n)