0%

528. Random Pick with Weight

O(logn) time O(n) space
[1, 3]是频数,即构造成下标数组[0, 1, 1, 1]然后随机一个index
累加频数,构造频数和数组[1, 4],这里最后一个4是所有频数的和,即数组[0, 1, 1, 1]的长度,随机以后得到一个下标,需要得到下标所对应的原数组的index,因为频数和数组是递增的,所以二分可得到对应的频数和位置,即为原频数数组的下标

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
public:
Solution(vector<int>& w) {
v = w;
partial_sum(begin(v), end(v), begin(v), plus<int>());
srand(time(NULL));
}

int pickIndex() {
return upper_bound(begin(v), end(v), rand() % v.back()) - begin(v);
}

vector<int> v;
};

/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(w);
* int param_1 = obj->pickIndex();
*/
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
public:
Solution(vector<int>& w) {
v = w;
for (int i = 1; i < v.size(); ++i) {
v[i] += v[i - 1];
}
srand((unsigned)time(NULL));
}

int pickIndex() {
return lower_bound(begin(v), end(v), (rand() % v.back() + 1)) - begin(v); // 加1恢复成1-indexed
}

vector<int> v;
};

/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(w);
* int param_1 = obj->pickIndex();
*/
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
public:
Solution(vector<int>& w) {
int sum = 0;
for (int x : w) {
v.push_back(sum += x); // 累加频数
}
srand((unsigned)time(NULL));
}

int pickIndex() {
return distance(begin(v), upper_bound(begin(v), end(v), rand() % v.back())); // 这里要用upper_bound因为rand出来的是0-indexed,而频数本身是1-indexed,用lower_bound会找错,举例[1, 4]rand % 4出来的是下标1,而不是频数1,upper_bound找到的是频数和4,而lower_bound找到的是错误的频数和1
}

vector<int> v;
};

/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(w);
* int param_1 = obj->pickIndex();
*/

follow-up如果weight数组是mutable的
用线段树,因为需要前缀和,把二分查找前缀和的上界改成二分查找区间和的上界,只是这个区间和是从0开始的
update O(logn)
query O(logn)
pickIndex O(logn*logn)
需要注意update和pickIndex的比例,如果很少update也可以考虑普通前缀和来做这样update虽然是O(n)但是pickIndex是O(logn)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Solution {
public:
Solution(vector<int>& w) {
n = w.size();
v.resize(n * 2);
copy(begin(w), end(w), begin(v) + n);
for (int i = n - 1; i > 0; --i) {
v[i] = v[i * 2] + v[i * 2 + 1];
}
srand(time(NULL));
}

void update(int i, int val) {
i += n;
for (int d = val - v[i]; i > 0; i >>= 1) {
v[i] += d;
}
}

int query(int i) {
int res = 0;
for (int l = n, r = i + n; l <= r; l >>= 1, r >>= 1) {
if (l & 1) {
res += v[l++];
}
if ((r & 1) == 0) {
res += v[r--];
}
}
return res;
}

int pickIndex() {
int x = rand() % v[1], l = 0, r = n - 1;
while (l < r) {
int m = l + (r - l) / 2;
if (query(m) <= x) { // <是找下界<=是找上界c
l = m + 1;
} else {
r = m;
}
}
return l;
}

vector<int> v;
int n;
};

/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(w);
* int param_1 = obj->pickIndex();
*/