Binary Trees With Factors
Problem
Given a list of unique integers ≥ 2, count distinct binary trees where each non-leaf node's value equals the product of its two children's values. Each integer in the list may be reused as many times as you want. Return the count modulo 10⁹ + 7.
# Example: arr = [2, 4] -> 3
def num_factored_binary_trees(arr):
    """Count binary trees whose non-leaf nodes equal the product of their children.

    Every value of ``arr`` may be used any number of times. A single node is a
    valid tree on its own.

    Args:
        arr: Unique integers >= 2. The list is not modified.

    Returns:
        The number of such trees, modulo 10**9 + 7.
    """
    MOD = 10**9 + 7
    # Work on a sorted copy so the caller's list keeps its order. Ascending order
    # guarantees both factors of x are fully counted before x itself.
    values = sorted(arr)
    # dp[x] = number of trees rooted at x; each value starts as a lone leaf.
    dp = {x: 1 for x in values}
    for x in values:
        for a in values:
            # Past sqrt(x), the cofactor b = x // a is smaller than a, so that
            # pair was already counted (doubled below) when the loop reached b.
            if a * a > x:
                break
            if x % a == 0 and (x // a) in dp:
                b = x // a
                ways = dp[a] * dp[b]
                if a != b:
                    ways *= 2  # the two children can swap sides
                dp[x] = (dp[x] + ways) % MOD
    return sum(dp.values()) % MOD
/**
 * Counts the binary trees that can be built from the values in `arr`.
 * Each value may be reused freely, and every non-leaf node must equal the
 * product of its two children.
 *
 * @param {number[]} arr Unique integers >= 2. Sorted ascending in place.
 * @returns {number} The tree count modulo 1e9 + 7.
 */
function numFactoredBinaryTrees(arr) {
  const MODULUS = 1_000_000_007n;

  // Ascending order: every factor of a root is finished before the root.
  arr.sort((p, q) => p - q);

  const present = new Set(arr);
  // treesAt.get(v) = number of trees rooted at v (BigInt); a lone node counts once.
  const treesAt = new Map(arr.map((v) => [v, 1n]));

  for (const root of arr) {
    for (let i = 0; i < arr.length; i++) {
      const left = arr[i];
      // Beyond sqrt(root) the factor pairs are mirrors of ones already seen.
      if (left * left > root) break;
      if (!(root % left === 0)) continue;
      const right = root / left;
      if (!present.has(right)) continue;

      const pairs = treesAt.get(left) * treesAt.get(right);
      // Distinct children can appear in either order.
      const contribution = left !== right ? pairs * 2n : pairs;
      treesAt.set(root, (treesAt.get(root) + contribution) % MODULUS);
    }
  }

  let sum = 0n;
  treesAt.forEach((count) => {
    sum = (sum + count) % MODULUS;
  });
  return Number(sum);
}
class Solution {
    /**
     * Counts the binary trees that can be built from the values in {@code arr}.
     * Each value may be reused freely, and every non-leaf node must equal the
     * product of its two children.
     *
     * @param arr unique integers >= 2; sorted ascending in place
     * @return the tree count modulo 1_000_000_007
     */
    public int numFactoredBinaryTrees(int[] arr) {
        final long modulus = 1_000_000_007L;

        // Ascending order: both factors of a root are finalized before the root.
        Arrays.sort(arr);

        Set<Integer> values = new HashSet<>();
        // ways.get(v) = number of trees rooted at v; a lone node counts once.
        Map<Integer, Long> ways = new HashMap<>();
        for (int v : arr) {
            ways.put(v, 1L);
            values.add(v);
        }

        for (int root : arr) {
            for (int i = 0; i < arr.length; i++) {
                int left = arr[i];
                // Widen before squaring so large values cannot overflow int.
                if ((long) left * left > root) {
                    break;
                }
                if (root % left != 0) {
                    continue;
                }
                int right = root / left;
                if (!values.contains(right)) {
                    continue;
                }
                // Both factors are residues < modulus, so the product fits in a long.
                long combos = ways.get(left) * ways.get(right) % modulus;
                if (left != right) {
                    combos = combos * 2 % modulus; // children can swap sides
                }
                ways.put(root, (ways.get(root) + combos) % modulus);
            }
        }

        long total = ways.values().stream().reduce(0L, (p, q) -> (p + q) % modulus);
        return (int) total;
    }
}
class Solution {
public:
    /// @brief Counts the binary trees that can be built from the values in
    ///        `arr` (each reusable any number of times) where every non-leaf
    ///        node equals the product of its two children.
    /// @param arr Unique integers >= 2. Sorted ascending in place.
    /// @return The tree count modulo 1'000'000'007.
    int numFactoredBinaryTrees(std::vector<int>& arr) {
        // Use long long, not long. On LLP64 platforms (MSVC/Windows) `long` is
        // 32-bit, so a product of two residues (~1e18) or a * a would overflow.
        constexpr long long kMod = 1'000'000'007LL;

        // Ascending order guarantees both factors of x are finalized before x.
        std::sort(arr.begin(), arr.end());

        // dp[x] = number of trees rooted at x; a lone node is one tree.
        // Its key set doubles as the membership test for candidate factors.
        std::unordered_map<int, long long> dp;
        dp.reserve(arr.size());
        for (const int x : arr) dp[x] = 1;

        for (const int x : arr) {
            for (const int a : arr) {
                // Past sqrt(x) the cofactor is smaller than a, so that pair was
                // already counted (and doubled) when the loop reached it.
                if (static_cast<long long>(a) * a > x) break;
                if (x % a != 0) continue;
                const auto b_it = dp.find(x / a);
                if (b_it == dp.end()) continue;  // cofactor not in arr

                // Both operands are residues < kMod, so the product fits in 64 bits.
                long long ways = dp[a] * b_it->second % kMod;
                if (a != b_it->first) ways = ways * 2 % kMod;  // children can swap sides
                dp[x] = (dp[x] + ways) % kMod;
            }
        }

        long long total = 0;
        for (const auto& kv : dp) total = (total + kv.second) % kMod;
        return static_cast<int>(total);
    }
};
Explanation
Each non-leaf node's value must equal its two children multiplied together. So we count trees by their root value: dp[x] = how many valid trees have x at the root.
Every value can stand alone as a single-node tree, so each dp[x] starts at 1. We process values in sorted order so that whenever we build a tree rooted at x, both children (which are smaller factors) are already counted.
For root x we look for a factor pair a * b = x where both a and b are in the array. Each such pair contributes dp[a] * dp[b] new trees (any left-shape times any right-shape). If a != b we double it, because the two children can swap sides. We add these into dp[x] under the modulo.
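The inner loop stops early once a * a > x. Past √x, the cofactor b = x / a is smaller than a, so the pair (a, b) is just the mirror image of (b, a). That pair was already counted, including the swapped order via the doubling, when the loop visited b. The final answer is the sum of dp[x] over all x, taken modulo 10⁹ + 7.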
Example: arr = [2, 4]. dp[2] = 1. For 4, the pair 2 · 2 adds dp[2] * dp[2] = 1, so dp[4] = 2. Total 1 + 2 = 3.