Range Sum Query 2D - Mutable
Problem
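Design a structure over an integer matrix that supports two operations: update(r, c, val), which sets cell (r, c) to val, and sumRegion(r1, c1, r2, c2), which returns the sum of the rectangle whose top-left corner is (r1, c1) and bottom-right corner is (r2, c2), both inclusive.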
# Example: m = [[3,0,1],[5,6,3]]; sumRegion(0,0,1,2); update(0,1,4); sumRegion(0,0,1,2) -> [18, 22]
class NumMatrix:
    """Mutable 2D matrix supporting point updates and rectangle sums.

    Backed by a 2D Fenwick tree (binary indexed tree), so ``update`` and
    ``sumRegion`` each walk two nested lowbit chains: O(log R * log C).

    >>> nm = NumMatrix([[3, 0, 1], [5, 6, 3]])
    >>> nm.sumRegion(0, 0, 1, 2)
    18
    >>> nm.update(0, 1, 4)
    >>> nm.sumRegion(0, 0, 1, 2)
    22
    """

    def __init__(self, m) -> None:
        """Build the tree from the rows of ``m``; an empty matrix is allowed."""
        self.R = len(m)
        # Guard m[0] so an empty matrix yields a 0x0 structure instead of IndexError.
        self.C = len(m[0]) if m else 0
        # 1-indexed Fenwick tree; row 0 and column 0 are unused padding.
        self.bit = [[0] * (self.C + 1) for _ in range(self.R + 1)]
        # Current cell values, needed to turn an assignment into a delta.
        self.a = [[0] * self.C for _ in range(self.R)]
        for r, row in enumerate(m):
            for c in range(self.C):
                self.update(r, c, row[c])

    def update(self, r, c, val) -> None:
        """Set cell (r, c) to ``val``.

        Raises:
            IndexError: if (r, c) lies outside the matrix. Without this check a
                negative index would pass ``self.a[r][c]`` (Python wraps it) and
                the Fenwick walk would then spin forever, since ``0 & -0 == 0``
                never advances the index.
        """
        if not (0 <= r < self.R and 0 <= c < self.C):
            raise IndexError(
                f"cell ({r}, {c}) out of range for {self.R}x{self.C} matrix"
            )
        diff = val - self.a[r][c]
        self.a[r][c] = val
        i = r + 1
        while i <= self.R:
            j = c + 1
            while j <= self.C:
                self.bit[i][j] += diff
                j += j & -j  # next bucket whose range covers column c
            i += i & -i  # next bucket whose range covers row r

    def _pref(self, r, c) -> int:
        """Return the sum of rows 0..r-1 and columns 0..c-1 (r, c are counts)."""
        s = 0
        i = r
        while i > 0:
            j = c
            while j > 0:
                s += self.bit[i][j]
                j -= j & -j  # strip the lowest set bit
            i -= i & -i
        return s

    def sumRegion(self, r1, c1, r2, c2) -> int:
        """Return the sum of the rectangle (r1, c1)..(r2, c2), corners inclusive."""
        # Inclusion-exclusion: whole box, minus the strip above, minus the
        # strip to the left, plus the top-left corner that was removed twice.
        return (self._pref(r2 + 1, c2 + 1) - self._pref(r1, c2 + 1)
                - self._pref(r2 + 1, c1) + self._pref(r1, c1))
/**
 * Mutable 2D matrix supporting point updates and rectangle sums, backed by a
 * 2D Fenwick tree (binary indexed tree): O(log R * log C) per operation.
 */
class NumMatrix {
  /**
   * Build the tree from matrix m; an empty matrix is allowed.
   * @param {number[][]} m
   */
  constructor(m) {
    this.R = m.length;
    // Guard m[0] so an empty matrix yields a 0x0 structure instead of a TypeError.
    this.C = this.R > 0 ? m[0].length : 0;
    // 1-indexed Fenwick tree; row 0 and column 0 are unused padding.
    this.bit = Array.from({length: this.R + 1}, () => new Array(this.C + 1).fill(0));
    // Current cell values, needed to turn an assignment into a delta.
    this.a = Array.from({length: this.R}, () => new Array(this.C).fill(0));
    for (let r = 0; r < this.R; r++) for (let c = 0; c < this.C; c++) this.update(r, c, m[r][c]);
  }

  /**
   * Set cell (r, c) to val.
   * @throws {RangeError} if (r, c) is not an integer cell inside the matrix.
   *   Without this check a negative c yields a NaN diff and the inner Fenwick
   *   walk never terminates (0 & -0 === 0), and c === C silently grows `a`.
   */
  update(r, c, val) {
    const inside = Number.isInteger(r) && Number.isInteger(c) &&
      r >= 0 && r < this.R && c >= 0 && c < this.C;
    if (!inside) {
      throw new RangeError(`cell (${r}, ${c}) out of range for ${this.R}x${this.C} matrix`);
    }
    const diff = val - this.a[r][c];
    this.a[r][c] = val;
    // Climb both lowbit chains, touching every bucket that covers (r, c).
    for (let i = r + 1; i <= this.R; i += i & -i)
      for (let j = c + 1; j <= this.C; j += j & -j) this.bit[i][j] += diff;
  }

  /** Sum of rows 0..r-1 and columns 0..c-1 (r and c are counts, 1-based). */
  _pref(r, c) {
    let s = 0;
    for (let i = r; i > 0; i -= i & -i)
      for (let j = c; j > 0; j -= j & -j) s += this.bit[i][j];
    return s;
  }

  /** Sum of the rectangle (r1, c1)..(r2, c2), corners inclusive. */
  sumRegion(r1, c1, r2, c2) {
    // Inclusion-exclusion over four prefix rectangles.
    return this._pref(r2+1, c2+1) - this._pref(r1, c2+1) - this._pref(r2+1, c1) + this._pref(r1, c1);
  }
}
/**
 * Mutable 2D matrix supporting point updates and rectangle sums, backed by a
 * 2D Fenwick tree (binary indexed tree): O(log R * log C) per operation.
 */
class NumMatrix {
    // Dimensions, 1-indexed Fenwick tree, and current cell values.
    int R, C; int[][] bit, a;

    /** Builds the tree from m; an empty matrix yields a 0x0 structure. */
    public NumMatrix(int[][] m) {
        R = m.length;
        // Guard m[0] so an empty matrix does not throw ArrayIndexOutOfBoundsException.
        C = R > 0 ? m[0].length : 0;
        bit = new int[R + 1][C + 1]; a = new int[R][C];
        for (int r = 0; r < R; r++) for (int c = 0; c < C; c++) update(r, c, m[r][c]);
    }

    /**
     * Sets cell (r, c) to val. An out-of-range cell throws
     * ArrayIndexOutOfBoundsException from the a[r][c] read before the tree is touched.
     */
    public void update(int r, int c, int val) {
        int diff = val - a[r][c]; a[r][c] = val;
        // Climb both lowbit chains, touching every bucket that covers (r, c).
        for (int i = r + 1; i <= R; i += i & -i)
            for (int j = c + 1; j <= C; j += j & -j) bit[i][j] += diff;
    }

    /** Sum of rows 0..r-1 and columns 0..c-1 (r and c are counts, 1-based). */
    int pref(int r, int c) {
        int s = 0;
        for (int i = r; i > 0; i -= i & -i) for (int j = c; j > 0; j -= j & -j) s += bit[i][j];
        return s;
    }

    /** Sum of the rectangle (r1, c1)..(r2, c2), corners inclusive. */
    public int sumRegion(int r1, int c1, int r2, int c2) {
        // Inclusion-exclusion over four prefix rectangles.
        return pref(r2+1, c2+1) - pref(r1, c2+1) - pref(r2+1, c1) + pref(r1, c1);
    }
}
// Mutable 2D matrix supporting point updates and rectangle sums, backed by a
// 2D Fenwick tree (binary indexed tree): O(log R * log C) per operation.
class NumMatrix {
    // Dimensions, 1-indexed Fenwick tree `bit`, and current cell values `a`.
    int R, C;
    std::vector<std::vector<int>> bit, a;
public:
    // Build the tree from m; an empty matrix yields a 0x0 structure.
    NumMatrix(const std::vector<std::vector<int>>& m) {
        R = static_cast<int>(m.size());
        // Guard m[0]: indexing an empty vector is undefined behavior.
        C = R > 0 ? static_cast<int>(m[0].size()) : 0;
        bit.assign(R + 1, std::vector<int>(C + 1, 0));
        a.assign(R, std::vector<int>(C, 0));
        for (int r = 0; r < R; r++)
            for (int c = 0; c < C; c++) update(r, c, m[r][c]);
    }

    // Set cell (r, c) to val. a.at() throws std::out_of_range for a cell
    // outside the matrix; without it a negative c would read out of bounds
    // and the inner Fenwick walk would never end (0 & -0 == 0).
    void update(int r, int c, int val) {
        int diff = val - a.at(r).at(c);
        a[r][c] = val;
        // Climb both lowbit chains, touching every bucket that covers (r, c).
        for (int i = r + 1; i <= R; i += i & -i)
            for (int j = c + 1; j <= C; j += j & -j) bit[i][j] += diff;
    }

    // Sum of rows 0..r-1 and columns 0..c-1 (r and c are counts, 1-based).
    int pref(int r, int c) {
        int s = 0;
        for (int i = r; i > 0; i -= i & -i)
            for (int j = c; j > 0; j -= j & -j) s += bit[i][j];
        return s;
    }

    // Sum of the rectangle (r1, c1)..(r2, c2), corners inclusive, by
    // inclusion-exclusion over four prefix rectangles.
    int sumRegion(int r1, int c1, int r2, int c2) {
        return pref(r2+1, c2+1) - pref(r1, c2+1) - pref(r2+1, c1) + pref(r1, c1);
    }
};
Explanation
We need both fast updates and fast rectangle sums on a grid. A plain prefix-sum array gives fast sums but slow updates; this solution uses a 2D Fenwick tree (Binary Indexed Tree) to make both fast.
A Fenwick tree stores partial sums at clever positions chosen by the lowbit trick i & -i (the lowest set bit). To update(r, c, val) we first compute the change diff from the old value, then walk i and j upward (i += i & -i), adding diff to every bucket responsible for that cell — two nested lowbit chains.
The helper _pref(r, c) walks downward (i -= i & -i) summing buckets. This gives the total of the top-left r × c rectangle, i.e. rows 0..r-1 and columns 0..c-1. Because the tree is 1-indexed, sumRegion passes r2+1 and c2+1 to include row r2 and column c2. To get an arbitrary rectangle, sumRegion uses inclusion-exclusion over four such prefix queries:
- start with the big box,
- subtract the strip above,
- subtract the strip to the left,
- add back the top-left corner, because it was subtracted twice.
Example: for m = [[3,0,1],[5,6,3]], sumRegion(0,0,1,2) totals the whole grid = 18. After update(0,1,4) raises the 0 to 4, the diff of +4 ripples through the tree and the same query now returns 22.
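The row chain has length O(log R) and the column chain O(log C), and the two are nested. So update and sumRegion each run in O(log R · log C) time, which is O(log² n) for an n × n grid. The constructor calls update once per cell, so building takes O(RC · log R · log C) time. The structure uses O(RC) space.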