Trim a Binary Search Tree
Problem
Given the root of a binary search tree and the lowest and highest boundaries as low and high, trim the tree so that all its elements lie in [low, high]. Return the new root after trimming.
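Input: root = [3,0,4,null,2,null,null,1], low = 1, high = 3
Output: [3,2,null,1]

```python
def trim_bst(root, low, high):
    # Empty subtree: nothing to trim.
    if not root:
        return None
    # root and its whole left subtree are < low: keep only the trimmed right side.
    if root.val < low:
        return trim_bst(root.right, low, high)
    # root and its whole right subtree are > high: keep only the trimmed left side.
    if root.val > high:
        return trim_bst(root.left, low, high)
    # root is in range: keep it and trim both children in place.
    root.left = trim_bst(root.left, low, high)
    root.right = trim_bst(root.right, low, high)
    return root
```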
```javascript
function trimBST(root, low, high) {
  if (!root) return null;
  if (root.val < low) return trimBST(root.right, low, high);
  if (root.val > high) return trimBST(root.left, low, high);
  root.left = trimBST(root.left, low, high);
  root.right = trimBST(root.right, low, high);
  return root;
}
```
```java
class Solution {
    public TreeNode trimBST(TreeNode root, int low, int high) {
        if (root == null) return null;
        if (root.val < low) return trimBST(root.right, low, high);
        if (root.val > high) return trimBST(root.left, low, high);
        root.left = trimBST(root.left, low, high);
        root.right = trimBST(root.right, low, high);
        return root;
    }
}
```
```cpp
TreeNode* trimBST(TreeNode* root, int low, int high) {
    if (!root) return nullptr;
    if (root->val < low) return trimBST(root->right, low, high);
    if (root->val > high) return trimBST(root->left, low, high);
    root->left = trimBST(root->left, low, high);
    root->right = trimBST(root->right, low, high);
    return root;
}
```
Explanation
We want to drop every node whose value falls outside [low, high], but keep the tree a valid BST. The BST ordering lets us prune whole branches at once instead of checking nodes one by one.
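One way to check that trimming really keeps a valid BST is to compare in-order traversals: the in-order sequence of the trimmed tree should be exactly the original sorted sequence filtered to [low, high]. The sketch below assumes a minimal TreeNode class with val, left and right fields and distinct values; inorder and is_bst are illustrative helpers, not part of the problem.

```python
from typing import List


class TreeNode:
    # Minimal node type (assumed shape: val, left, right).
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_bst(root, low, high):
    # Same recursion as the Python solution above.
    if not root:
        return None
    if root.val < low:
        return trim_bst(root.right, low, high)
    if root.val > high:
        return trim_bst(root.left, low, high)
    root.left = trim_bst(root.left, low, high)
    root.right = trim_bst(root.right, low, high)
    return root


def inorder(node) -> List[int]:
    # In-order traversal of a BST yields its values in ascending order.
    return inorder(node.left) + [node.val] + inorder(node.right) if node else []


def is_bst(node, lo=float("-inf"), hi=float("inf")) -> bool:
    # Each value must lie strictly between the bounds inherited from its ancestors
    # (strict comparisons assume distinct values).
    if node is None:
        return True
    if not lo < node.val < hi:
        return False
    return is_bst(node.left, lo, node.val) and is_bst(node.right, node.val, hi)


# Example tree: 3 -> (0 -> (None, 2 -> (1, None)), 4)
root = TreeNode(3, TreeNode(0, None, TreeNode(2, TreeNode(1))), TreeNode(4))
before = inorder(root)          # [0, 1, 2, 3, 4]
trimmed = trim_bst(root, 1, 3)
after = inorder(trimmed)        # [1, 2, 3]
```

Because trimming only removes nodes and splices whole subtrees upward, the surviving values keep their relative order, which is what the in-order comparison captures.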
The recursion trim_bst(root, low, high) looks at the current node. If root.val < low, then the root and its entire left subtree are too small, so we throw them away and return the trimmed right subtree. Symmetrically, if root.val > high, we discard the root and its right subtree and return the trimmed left subtree.
If the node is inside the range, we keep it and simply recurse into both children, reassigning root.left and root.right to their trimmed versions before returning the node.
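The same pruning rules can also be applied without recursion. The sketch below is an iterative variant (a different formulation from the recursive solutions above, not a replacement for them) that uses O(1) extra space: it first walks down to a root inside the range, then repairs the left spine, where only the low bound can fail, and the right spine, where only the high bound can fail. TreeNode and trim_bst_iterative are illustrative names assumed here.

```python
from typing import Optional


class TreeNode:
    # Minimal node type (assumed shape: val, left, right).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def trim_bst_iterative(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    # Step 1: descend until the current root lies in [low, high].
    while root and (root.val < low or root.val > high):
        root = root.right if root.val < low else root.left
    if root is None:
        return None

    # Step 2: walk the left spine. Everything in root's left subtree is smaller
    # than root.val <= high, so only the low bound can be violated on this side.
    node = root
    while node.left:
        if node.left.val < low:
            # Drop the too-small child and its left subtree; its right subtree moves up.
            node.left = node.left.right
        else:
            node = node.left

    # Step 3: mirror image on the right spine, where only the high bound matters.
    node = root
    while node.right:
        if node.right.val > high:
            # Drop the too-large child and its right subtree; its left subtree moves up.
            node.right = node.right.left
        else:
            node = node.right
    return root
```

On the example tree with low = 1 and high = 3 this produces the same shape as the recursive version: 3 with left child 2, whose left child is 1.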
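Example: [3,0,4,null,2,null,null,1] with low=1, high=3. The root 3 is in range, so we keep it and trim both children. Its left child 0 is too small, so 0 is discarded and replaced by its trimmed right subtree, rooted at 2 (whose left child 1 is in range and stays). Its right child 4 is too big and has no left subtree, so it trims to null. The result is [3,2,null,1].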
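To reproduce this example end to end, the bracket notation has to be decoded into nodes and encoded back. The sketch below assumes the common level-order convention (None marks a missing child, trailing Nones omitted); from_level_order and to_level_order are hypothetical helper names, not part of the problem.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Minimal node type (assumed shape: val, left, right).
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def from_level_order(values: List[Optional[int]]) -> Optional[TreeNode]:
    # Decode level-order notation: children are listed left then right, None = missing.
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    # Encode back to the same notation, dropping trailing None entries.
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def trim_bst(root, low, high):
    # Same recursion as the Python solution above.
    if not root:
        return None
    if root.val < low:
        return trim_bst(root.right, low, high)
    if root.val > high:
        return trim_bst(root.left, low, high)
    root.left = trim_bst(root.left, low, high)
    root.right = trim_bst(root.right, low, high)
    return root


root = from_level_order([3, 0, 4, None, 2, None, None, 1])
print(to_level_order(trim_bst(root, 1, 3)))  # [3, 2, None, 1]
```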
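Each node is visited at most once, and a subtree discarded in a single pruning step is never visited at all, so trimming runs in O(n) time. The recursion uses O(h) stack space, where h is the height of the tree (up to O(n) for a skewed tree).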
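The bound is often loose in practice because pruned subtrees are skipped. The sketch below instruments the recursion with a visit counter (trim_bst_counted and insert are illustrative helpers assumed here) and builds a small BST by plain insertion of an arbitrary sample of values.

```python
from typing import List, Optional


class TreeNode:
    # Minimal node type (assumed shape: val, left, right).
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    # Plain (unbalanced) BST insertion, used only to build a test tree.
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def trim_bst_counted(root, low, high, visits: List[int]):
    # Same recursion as trim_bst, but visits[0] counts the non-null nodes it inspects.
    if not root:
        return None
    visits[0] += 1
    if root.val < low:
        return trim_bst_counted(root.right, low, high, visits)
    if root.val > high:
        return trim_bst_counted(root.left, low, high, visits)
    root.left = trim_bst_counted(root.left, low, high, visits)
    root.right = trim_bst_counted(root.right, low, high, visits)
    return root


values = [10, 5, 15, 2, 7, 1, 3, 6, 8, 12, 20]
root = None
for v in values:
    root = insert(root, v)

visits = [0]
trimmed = trim_bst_counted(root, 9, 13, visits)
print(visits[0], "of", len(values), "nodes visited")  # 6 of 11 nodes visited
```

Here node 5 is too small, so its left subtree (2, 1, 3) is never entered, and the same happens to 6 under 7 and to 20 under 15; only the path through 5, 7, 8 and the nodes 10, 15, 12 are inspected.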