0%

307. Range Sum Query - Mutable

常规线段树
指针版

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class NumArray {
struct Node {
Node(int b, int e, int val, Node *l = nullptr, Node *r = nullptr)
: b(b), e(e), val(val), l(l), r(r) {
}

int b, e, val;
Node *l, *r;
};
public:
NumArray(vector<int>& nums) {
if (nums.empty()) return;
n = nums.size();
root = build(nums, 0, n - 1);
}

Node *build(const vector<int> &A, int l, int r) {
if (l == r) {
return new Node(l, r, A[l]);
} else {
int m = l + (r - l) / 2;
auto tl = build(A, l, m);
auto tr = build(A, m + 1, r);
return new Node(l, r, tl->val + tr->val, tl, tr);
}
}

void update(int i, int val) {
update(root, i, val);
}

void update(Node *p, int i, int val) {
if (p->b == i && p->e == i) {
p->val = val;
} else {
int m = p->b + (p->e - p->b) / 2;
if (i <= m) {
update(p->l, i, val);
} else {
update(p->r, i, val);
}
p->val = p->l->val + p->r->val;
}
}

int sumRange(int i, int j) {
return query(root, i, j);
}

int query(Node *p, int l, int r) {
if (l > r) return 0;
if (p->b == l && p->e == r) return p->val;
int m = p->b + (p->e - p->b) / 2;
return query(p->l, l, min(m, r)) + query(p->r, max(l, m + 1), r);
}

int n;
Node *root = nullptr;
};

/**
* Your NumArray object will be instantiated and called as such:
* NumArray* obj = new NumArray(nums);
* obj->update(i,val);
* int param_2 = obj->sumRange(i,j);
*/

数组版

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class NumArray {
public:
NumArray(vector<int>& nums) {
if (nums.empty()) return;
n = nums.size();
tree.resize(4 * n); // 开大小为4n的数组
build(1, nums, 0, n - 1);
}

void build(int v, const vector<int> &A, int l, int r) {
if (l == r) {
tree[v] = A[l];
} else {
int m = l + (r - l) / 2;
build(v * 2, A, l, m);
build(v * 2 + 1, A, m + 1, r);
tree[v] = tree[v * 2] + tree[v * 2 + 1];
}
}

void update(int i, int val) {
update(1, 0, n - 1, i, val);
}

void update(int v, int l, int r, int p, int val) {
if (l == r) {
tree[v] = val;
} else {
int m = l + (r - l) / 2;
if (p <= m) {
update(v * 2, l, m, p, val);
} else {
update(v * 2 + 1, m + 1, r, p, val);
}
tree[v] = tree[v * 2] + tree[v * 2 + 1];
}
}

int sumRange(int i, int j) {
return query(1, 0, n - 1, i, j);
}

int query(int v, int tl, int tr, int l, int r) {
if (l > r) return 0;
if (tl == l && tr == r) return tree[v];
int m = tl + (tr - tl) / 2;
return query(v * 2, tl, m, l, min(m, r)) + query(v * 2 + 1, m + 1, tr, max(m + 1, l), r);
}

int n;
vector<int> tree;
};

/**
* Your NumArray object will be instantiated and called as such:
* NumArray* obj = new NumArray(nums);
* obj->update(i,val);
* int param_2 = obj->sumRange(i,j);
*/

zkw segment tree O(n) constructor O(logn) update O(logn) sum O(n) space
只支持bottom-up的题,top-down的题一律用常规线段树做
用数组来表示线段树
原始数组为[1, 2, 3, 4]则数组为[0, 10, 3, 7, 1, 2, 3, 4],
原始数组为[1, 2, 3]则数组为[0, 6, 5, 1, 2, 3]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class NumArray {
public:
NumArray(vector<int>& nums) {
n = nums.size();
tree.resize(2 * n); // 开大小为2n的数组
copy(begin(nums), end(nums), begin(tree) + n); // 把原始数组的n个数放在最后
for (int i = n - 1; i > 0; --i) { // 自底向上累加
tree[i] = tree[2 * i] + tree[2 * i + 1];
}
}

void update(int i, int val) {
i += n;
for (int d = val - tree[i]; i > 0; i >>= 1) { // 从叶节点开始用新旧数差值d自底向上来更新线段树
tree[i] += d;
}
}

int sumRange(int i, int j) {
int res = 0;
for (i += n, j += n; i <= j; i >>= 1, j >>= 1) { // 分别找到两个叶节点的位置,自底向上『递归』
if (i & 1) { // 如果左边界是奇数,则证明是一个右子树,直接累加并右移一个节点,如果左边界是偶数,则证明是一个左子树,其父节点是左右两子树之和,继续向上递归即可
res += tree[i++];
}
if ((j & 1) == 0) { // 如果右边界是偶数,则证明是一个左子树,直接累加并左移一个节点,如果右边界是奇数,则证明是一个右子树,其父节点是左右两子树之和,继续向上递归即可
res += tree[j--];
}
}
return res;
}

int n;
vector<int> tree;
};

/**
* Your NumArray object will be instantiated and called as such:
* NumArray* obj = new NumArray(nums);
* obj->update(i,val);
* int param_2 = obj->sumRange(i,j);
*/

树状数组
constructor O(nlogn)
update O(logn)
sumRange O(logn)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class NumArray {
public:
NumArray(vector<int> nums) {
v = nums;
int n = nums.size();
BIT.resize(n + 1);
for (int i = 0; i < n; ++i) {
helper(i, nums[i]);
}
}

int query(int i) {
int sum = 0;
++i;
while (i > 0) {
sum += BIT[i];
i -= (i & -i);
}
return sum;
}

void update(int i, int val) {
helper(i, val - v[i]);
v[i] = val;
}

void helper(int i, int val) {
++i;
while (i < BIT.size()) {
BIT[i] += val;
i += (i & -i);
}
}

int sumRange(int i, int j) {
return query(j) - query(i - 1);
}

vector<int> BIT, v;
};

/**
* Your NumArray object will be instantiated and called as such:
* NumArray obj = new NumArray(nums);
* obj.update(i,val);
* int param_2 = obj.sumRange(i,j);
*/

二维树状数组求区间和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <iostream>
#include <vector>

using namespace std;

struct Solution {
Solution(const vector<vector<int>> &mtx) : mtx(mtx) {
int n = mtx.size(), m = mtx[0].size();
BIT.resize(n + 1, vector<int>(m + 1));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
helper(i, j, mtx[i][j]);
}
}
}

int query(int i, int j) {
++i, ++j;
int sum = 0;
while (i > 0) {
int t = j; // 注意要保存j因为每次内循环要用原值
while (t > 0) {
sum += BIT[i][t];
t -= (t & -t);
}
i -= (i & -i);
}
return sum;
}

int query(int uli, int ulj, int lri, int lrj) {
return query(lri, lrj) - query(lri, ulj - 1) - query(uli - 1, lrj) + query(uli - 1, ulj - 1); // 注意跟一维的不一样,不能只减左上角还有两边的矩阵和也要减
}

void helper(int i, int j, int val) {
++i, ++j;
while (i < BIT.size()) {
int t = j;
while (t < BIT[i].size()) {
BIT[i][t] += val;
t += (t & -t);
}
i += (i & -i);
}
}

void update(int i, int j, int val) {
helper(i, j, val - mtx[i][j]);
mtx[i][j] = val;
}

vector<vector<int>> mtx, BIT;
};

int main() {
// your code goes here
vector<vector<int>> mtx = {
{3, 2, 2, 8, 1, 6, 4},
{4, 2, 7, 0, 2, 8, 1},
{2, 9, 1, 6, 5, 5, 5},
{2, 7, 4, 4, 1, 4, 8},
{0, 4, 7, 1, 2, 5, 8},
{7, 2, 8, 2, 1, 6, 9}
};
Solution s(mtx);
cout << s.query(0, 0, 5, 6) << endl;
cout << s.query(2, 3, 3, 5) << endl;
s.update(4, 2, 8);
cout << s.query(0, 0, 5, 6) << endl;
cout << s.query(1, 2, 4, 5) << endl;

return 0;
}