Count of Smaller Numbers After Self
Problem
For each index i, count how many elements nums[j] with j > i are strictly smaller than nums[i]. Return these counts as an array of the same length as nums.
# Example
#   Input:  nums = [5, 2, 6, 1]
#   Output: [2, 1, 1, 0]
def count_smaller(nums):
    """Return counts where counts[i] is the number of j > i with nums[j] < nums[i].

    Merge-sorts a list of positions (not the values themselves) by value.
    Within each merge, the left half holds original positions l..m-1 and the
    right half holds m..r-1, so any right-half position emitted before a
    left-half position is both later in ``nums`` and strictly smaller.

    Args:
        nums: sequence of comparable values (typically ints).

    Returns:
        A new list of ``len(nums)`` counts. ``nums`` itself is not modified.

    Runs in O(n log n) time with O(n) extra space.
    """
    n = len(nums)
    counts = [0] * n
    idx = list(range(n))  # positions, reordered in place by merge_sort

    def merge_sort(l, r):
        # Stable ascending sort of idx[l:r] by nums value, updating counts.
        if r - l <= 1:
            return
        m = (l + r) // 2
        merge_sort(l, m)
        merge_sort(m, r)

        merged = []
        i, j = l, m
        right_smaller = 0  # right-half positions emitted so far (later and smaller)
        while i < m and j < r:
            # Strict '<': on ties the left element goes first, so equal
            # values are never counted as "smaller".
            if nums[idx[j]] < nums[idx[i]]:
                merged.append(idx[j])
                j += 1
                right_smaller += 1
            else:
                counts[idx[i]] += right_smaller
                merged.append(idx[i])
                i += 1
        # Remaining left positions come after every emitted right element.
        for p in idx[i:m]:
            counts[p] += right_smaller
        merged.extend(idx[i:m])
        # Remaining right positions add nothing to any count.
        merged.extend(idx[j:r])
        idx[l:r] = merged

    merge_sort(0, n)
    return counts
/**
 * For every position i, counts how many later elements nums[j] (j > i) are
 * strictly smaller than nums[i].
 *
 * Works by merge-sorting an array of positions by value. In each merge the
 * left half holds earlier positions and the right half later ones, so every
 * right-half position emitted before a left-half position is later and
 * strictly smaller; the left position's count grows by that many.
 *
 * @param {number[]} nums input values (not modified)
 * @returns {number[]} counts, same length as nums
 */
function countSmaller(nums) {
  const size = nums.length;
  const result = new Array(size).fill(0);
  const order = Array.from({ length: size }, (_, pos) => pos);

  const sortRange = (lo, hi) => {
    if (hi - lo < 2) return;
    const mid = (lo + hi) >> 1;
    sortRange(lo, mid);
    sortRange(mid, hi);

    const buffer = [];
    let left = lo;
    let rightPos = mid;
    let takenFromRight = 0;
    while (left < mid || rightPos < hi) {
      // Take from the right when the left side is exhausted, or when the
      // right value is strictly smaller (ties keep the left element first).
      const takeRight =
        left === mid || (rightPos < hi && nums[order[rightPos]] < nums[order[left]]);
      if (takeRight) {
        buffer.push(order[rightPos++]);
        takenFromRight++;
      } else {
        result[order[left]] += takenFromRight;
        buffer.push(order[left++]);
      }
    }
    buffer.forEach((pos, k) => {
      order[lo + k] = pos;
    });
  };

  sortRange(0, size);
  return result;
}
/**
 * Count of Smaller Numbers After Self, solved with a merge sort over
 * positions. The fields hold per-call working state, so a single instance
 * must not be used from several threads at once.
 */
class Solution {
    int[] nums; Integer[] idx; int[] counts;

    /**
     * @param n input values (read, not modified)
     * @return list whose i-th entry is the number of j > i with n[j] < n[i]
     */
    public List<Integer> countSmaller(int[] n) {
        nums = n;
        final int size = n.length;
        counts = new int[size];
        idx = new Integer[size];
        for (int pos = 0; pos < size; pos++) {
            idx[pos] = pos;
        }

        mergeSort(0, size);

        List<Integer> result = new ArrayList<>();
        for (int pos = 0; pos < size; pos++) {
            result.add(counts[pos]);
        }
        return result;
    }

    /**
     * Stable ascending sort of idx[l, r) by nums value. Each left-half
     * position is credited with the number of right-half positions (all of
     * which are later in the original array) merged ahead of it.
     */
    void mergeSort(int l, int r) {
        if (r - l < 2) {
            return;
        }
        int mid = (l + r) / 2;
        mergeSort(l, mid);
        mergeSort(mid, r);

        Integer[] buffer = new Integer[r - l];
        int a = l, b = mid, out = 0, smallerSeen = 0;
        while (a < mid || b < r) {
            // Right wins only on a strictly smaller value; ties keep left first.
            boolean takeRight = a == mid || (b < r && nums[idx[b]] < nums[idx[a]]);
            if (takeRight) {
                buffer[out++] = idx[b++];
                smallerSeen++;
            } else {
                counts[idx[a]] += smallerSeen;
                buffer[out++] = idx[a++];
            }
        }
        System.arraycopy(buffer, 0, idx, l, buffer.length);
    }
}
// Working state shared by countSmaller() and mergeSort(). As visible here
// these are globals, so countSmaller() is not reentrant and must not run
// concurrently. NOTE(review): the trailing-underscore names suggest these may
// have been members of a LeetCode-style `Solution` class — confirm.
std::vector<int> nums_, idx_, counts_;

// Stable ascending merge sort of idx_[l, r) by nums_ value. The left half
// holds original positions l..m-1 and the right half m..r-1, so every
// right-half index merged ahead of a left-half index is a later, strictly
// smaller element; the left index's count is increased by that many.
void mergeSort(int l, int r) {
    if (r - l <= 1) return;
    const int m = l + (r - l) / 2;  // avoids overflow of l + r
    mergeSort(l, m);
    mergeSort(m, r);

    std::vector<int> merged;
    merged.reserve(static_cast<std::vector<int>::size_type>(r - l));  // single allocation
    int i = l;
    int j = m;
    int right = 0;  // right-half indices already emitted
    while (i < m && j < r) {
        // Strict '<': ties take the left element first, so equal values
        // are never counted as smaller.
        if (nums_[idx_[j]] < nums_[idx_[i]]) {
            merged.push_back(idx_[j++]);
            ++right;
        } else {
            counts_[idx_[i]] += right;
            merged.push_back(idx_[i++]);
        }
    }
    while (i < m) {
        counts_[idx_[i]] += right;
        merged.push_back(idx_[i++]);
    }
    while (j < r) merged.push_back(idx_[j++]);

    for (int k = 0; k < r - l; ++k) idx_[l + k] = merged[k];
}

// Returns counts where counts[i] is the number of j > i with nums[j] < nums[i].
// O(n log n) time, O(n) extra space. `nums` is only read (it is copied into
// nums_); the non-const reference is kept for signature compatibility.
std::vector<int> countSmaller(std::vector<int>& nums) {
    nums_ = nums;
    const int n = static_cast<int>(nums.size());
    idx_.assign(n, 0);
    for (int i = 0; i < n; ++i) idx_[i] = i;
    counts_.assign(n, 0);
    mergeSort(0, n);
    return counts_;
}
Explanation
For each element we need the number of later elements that are strictly smaller. Checking every pair takes O(n²) time, which is too slow. Instead, we piggyback on a merge sort that sorts indices rather than values. During each merge, we count the smaller elements that jump ahead from the right half.
We keep an array idx of positions and sort it by the values in nums. During the merge of a left half and a right half, we take an element from the right half only when it is strictly smaller than the current left element. Because the left half is already sorted, that right element is also strictly smaller than every left element still waiting.
We track a counter right of how many right-half elements have already been placed. When we finally take a left element idx[i], we add right to counts[idx[i]] — those right elements all came from later positions and were smaller.
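Sorting indices by value (instead of sorting the values themselves) preserves each element's original position, so we always know which counter to update even as the order gets shuffled. Each merge combines two adjacent ranges of original positions. Every element in the right half therefore comes from a later position than every element in the left half, so "taken from the right before a left element" means "later and strictly smaller." Ties are resolved by taking the left element first, so equal values are never counted. Each level of recursion does O(n) merge work over O(log n) levels, giving O(n log n) overall.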
Example: nums = [5, 2, 6, 1] gives [2, 1, 1, 0] — 5 has two smaller after it (2 and 1), 2 has one (1), 6 has one (1), and 1 has none.