Balance a Binary Search Tree
Problem
Given the root of a binary search tree, return a balanced binary search tree containing exactly the same values. A tree is balanced if, for every node, the heights of its left and right subtrees differ by at most one. If several balanced rearrangements exist, any one of them is accepted.
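The balance condition can be checked directly from its definition. This is a minimal sketch that assumes a LeetCode-style `TreeNode` with `val`, `left`, and `right` fields, defined here so the snippet runs on its own. `height_if_balanced` and `is_balanced` are hypothetical helper names. The helper returns -1 as soon as some node's two subtrees differ in height by more than one.

```python
from typing import Optional


class TreeNode:
    # Minimal node type, assumed to match the judge's TreeNode (val, left, right).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height_if_balanced(node: Optional[TreeNode]) -> int:
    """Return the subtree height in levels, or -1 if any node in it is unbalanced."""
    if node is None:
        return 0
    left = height_if_balanced(node.left)
    right = height_if_balanced(node.right)
    if left < 0 or right < 0 or abs(left - right) > 1:
        return -1
    return 1 + max(left, right)


def is_balanced(root: Optional[TreeNode]) -> bool:
    return height_if_balanced(root) >= 0


# A right-leaning chain 1 -> 2 -> 3 -> 4 fails the check at node 2...
chain = TreeNode(1, None, TreeNode(2, None, TreeNode(3, None, TreeNode(4))))
# ...while the tree [2,1,3,null,null,null,4] passes it.
balanced = TreeNode(2, TreeNode(1), TreeNode(3, None, TreeNode(4)))
print(is_balanced(chain), is_balanced(balanced))  # False True
```

Because the check returns early with -1, it visits each node at most once.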
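Input: root = [1,null,2,null,3,null,4]
Output: [2,1,3,null,null,null,4]

```python
# TreeNode (val, left, right) is assumed to be provided by the judge.
def balance_bst(root):
    vals = []

    # Phase 1: an in-order walk appends the values in sorted order.
    def inorder(node):
        if not node:
            return
        inorder(node.left)
        vals.append(node.val)
        inorder(node.right)

    # Phase 2: the middle element of vals[lo..hi] becomes the subtree root.
    def build(lo, hi):
        if lo > hi:
            return None
        mid = (lo + hi) // 2
        node = TreeNode(vals[mid])
        node.left = build(lo, mid - 1)
        node.right = build(mid + 1, hi)
        return node

    inorder(root)
    return build(0, len(vals) - 1)
```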
```javascript
function balanceBST(root) {
  const vals = [];
  function inorder(node) {
    if (!node) return;
    inorder(node.left);
    vals.push(node.val);
    inorder(node.right);
  }
  function build(lo, hi) {
    if (lo > hi) return null;
    const mid = (lo + hi) >> 1;
    const node = new TreeNode(vals[mid]);
    node.left = build(lo, mid - 1);
    node.right = build(mid + 1, hi);
    return node;
  }
  inorder(root);
  return build(0, vals.length - 1);
}
```
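```java
import java.util.ArrayList;
import java.util.List;

// TreeNode (val, left, right) is assumed to be provided by the judge.
class Solution {
    private final List<Integer> vals = new ArrayList<>();

    public TreeNode balanceBST(TreeNode root) {
        vals.clear(); // allow the same instance to be reused
        inorder(root);
        return build(0, vals.size() - 1);
    }

    // In-order walk: appends values in sorted order.
    private void inorder(TreeNode node) {
        if (node == null) return;
        inorder(node.left);
        vals.add(node.val);
        inorder(node.right);
    }

    // The middle element of vals[lo..hi] becomes the subtree root.
    private TreeNode build(int lo, int hi) {
        if (lo > hi) return null;
        int mid = (lo + hi) / 2;
        TreeNode node = new TreeNode(vals.get(mid));
        node.left = build(lo, mid - 1);
        node.right = build(mid + 1, hi);
        return node;
    }
}
```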
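```cpp
#include <vector>

// TreeNode (val, left, right) is assumed to be provided by the judge.
class Solution {
public:
    TreeNode* balanceBST(TreeNode* root) {
        vals.clear();  // allow the same instance to be reused
        inorder(root);
        return build(0, (int)vals.size() - 1);
    }

private:
    std::vector<int> vals;

    // In-order walk: appends values in sorted order.
    void inorder(TreeNode* node) {
        if (!node) return;
        inorder(node->left);
        vals.push_back(node->val);
        inorder(node->right);
    }

    // The middle element of vals[lo..hi] becomes the subtree root.
    TreeNode* build(int lo, int hi) {
        if (lo > hi) return nullptr;
        int mid = (lo + hi) / 2;
        TreeNode* node = new TreeNode(vals[mid]);
        node->left = build(lo, mid - 1);
        node->right = build(mid + 1, hi);
        return node;
    }
};
```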
Explanation
Two classic facts combine into the whole solution. First, an in-order traversal of a BST visits its values in sorted order. Second, given a sorted array, you can build a height-balanced BST by always making the middle element the root. Flatten, then rebuild — that is the entire algorithm.
Phase 1 is the flatten: a recursive in-order walk that appends every value to vals. For the example chain 1 → 2 → 3 → 4 this produces [1, 2, 3, 4]. Note that we never compare values — the BST property guarantees the list comes out sorted.
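A recursive walk is the most direct way to flatten. However, the input can be a long chain like the example, and CPython's default recursion limit (1000 frames, reported by `sys.getrecursionlimit()`) is then easy to exceed. This is a minimal sketch of the same flatten using an explicit stack. It assumes the same `TreeNode` shape, and `inorder_values` is a hypothetical name.

```python
from typing import List, Optional


class TreeNode:
    # Assumed LeetCode-style node: val, left, right.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """In-order traversal with an explicit stack instead of recursion."""
    vals: List[int] = []
    stack: List[TreeNode] = []
    node = root
    while stack or node is not None:
        # Descend as far left as possible, remembering the path.
        while node is not None:
            stack.append(node)
            node = node.left
        # The top of the stack holds the smallest value not yet emitted.
        node = stack.pop()
        vals.append(node.val)
        # Continue with the right subtree of the node just visited.
        node = node.right
    return vals


# A right-leaning chain 1 -> 2 -> ... -> 5000, deeper than the default recursion limit.
root = None
for v in range(5000, 0, -1):
    root = TreeNode(v, None, root)
print(inorder_values(root)[:4])  # [1, 2, 3, 4]
```

The stack holds at most one entry per level of the original tree, so the extra space matches the recursive version. The difference is that it lives on the heap rather than the call stack.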
Phase 2 is the rebuild: build(lo, hi) takes the range vals[lo..hi] of the sorted array and picks mid = (lo + hi) / 2, rounded down. It then makes vals[mid] the root of that range and recursively builds vals[lo..mid-1] as its left subtree and vals[mid+1..hi] as its right subtree. Everything left of mid is smaller and everything right of it is larger, so the result is still a valid BST.
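One way to confirm that the rebuild keeps the BST ordering is to check every node against the open interval inherited from its ancestors. This sketch re-implements the middle-element rebuild over a plain sorted list and validates the result. `build_from_sorted` and `is_valid_bst` are hypothetical names, `TreeNode` is an assumed minimal node class, and the strict bounds assume the values are distinct.

```python
import math
from typing import List, Optional


class TreeNode:
    # Assumed LeetCode-style node: val, left, right.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_from_sorted(vals: List[int]) -> Optional[TreeNode]:
    """Middle-element rebuild over a sorted list."""
    def build(lo: int, hi: int) -> Optional[TreeNode]:
        if lo > hi:
            return None
        mid = (lo + hi) // 2
        node = TreeNode(vals[mid])
        node.left = build(lo, mid - 1)
        node.right = build(mid + 1, hi)
        return node
    return build(0, len(vals) - 1)


def is_valid_bst(node: Optional[TreeNode], low=-math.inf, high=math.inf) -> bool:
    """Each value must lie strictly between the bounds set by its ancestors."""
    if node is None:
        return True
    if not (low < node.val < high):
        return False
    return is_valid_bst(node.left, low, node.val) and is_valid_bst(node.right, node.val, high)


tree = build_from_sorted(list(range(1, 101)))
print(is_valid_bst(tree))  # True
```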
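Why is it balanced? A range of m values puts its middle element at the root. That leaves ⌊(m - 1)/2⌋ values for the left subtree and ⌈(m - 1)/2⌉ for the right, so the two halves differ in size by at most one. By induction on m, the builder turns a range of m values into a tree with exactly ⌈log2(m + 1)⌉ levels, which is the minimum possible. This count grows by at most one when m grows by one. Every node is the root of the range it was built from, so its two subtrees differ in height by at most one, which is exactly the balance condition. The whole tree has ⌈log2(n + 1)⌉ levels, i.e. O(log n) height.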
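The induction can also be spot-checked numerically without building any nodes. For a range of m values, mid = (lo + hi) // 2 leaves (m - 1) // 2 values on the left and the rest on the right. The sketch below uses a hypothetical helper, `built_height`, to confirm two things for small m: the halves' heights never differ by more than one, and the level count equals ⌈log2(m + 1)⌉, which for m ≥ 0 is Python's `m.bit_length()`. Every node in a built tree is the root of some subrange, so checking every m covers every node.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def built_height(m: int) -> int:
    """Number of levels in the tree the middle-element rebuild produces from m sorted values."""
    if m == 0:
        return 0
    left = (m - 1) // 2      # values strictly before mid = (lo + hi) // 2
    right = m - 1 - left     # values strictly after mid
    return 1 + max(built_height(left), built_height(right))


for m in range(1, 500):
    left = (m - 1) // 2
    right = m - 1 - left
    # Balance condition at the root of every subrange of size m.
    assert abs(built_height(left) - built_height(right)) <= 1
    # Minimum possible height: ceil(log2(m + 1)) levels.
    assert built_height(m) == m.bit_length()

print(built_height(4))  # 3 levels, i.e. height 2 counted in edges
```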
On the example above, build(0, 3) picks mid index 1, so 2 becomes the root. The left half [0..0] places 1 as its left child. The right half [2..3] picks mid index 2, placing 3 as the right child, and finally [3..3] makes 4 the right child of 3. The height, counted in edges, drops from 3 to 2.
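The trace can be reproduced end to end. The sketch below builds the example chain, flattens and rebuilds it, and serializes the result in level order. Missing children appear as `None` and trailing `None`s are trimmed, which mirrors the bracket notation in the example (the exact serialization the judge uses is an assumption here). `TreeNode`, `balance`, and `level_order` are hypothetical names for this sketch.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Assumed LeetCode-style node: val, left, right.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def balance(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Flatten with an in-order walk, then rebuild around middle elements."""
    vals: List[int] = []

    def inorder(node):
        if node:
            inorder(node.left)
            vals.append(node.val)
            inorder(node.right)

    def build(lo, hi):
        if lo > hi:
            return None
        mid = (lo + hi) // 2
        return TreeNode(vals[mid], build(lo, mid - 1), build(mid + 1, hi))

    inorder(root)
    return build(0, len(vals) - 1)


def level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Breadth-first serialization with None for missing children, trailing Nones removed."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


chain = TreeNode(1, None, TreeNode(2, None, TreeNode(3, None, TreeNode(4))))
print(level_order(chain))           # [1, None, 2, None, 3, None, 4]
print(level_order(balance(chain)))  # [2, 1, 3, None, None, None, 4]
```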
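The traversal visits each node once and the rebuild creates each node once, so the time is O(n). The vals array takes O(n) space. The recursion stack adds the original tree's height for the in-order walk (up to n for a chain like the example) and O(log n) for the rebuild, so the total extra space is O(n).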