对于给定数组nums,修改nums中某个下标的值的操作记做update(index, val),获得nums某个区间元素和的操作记做query(start,end)。
若直接对原数组nums操作,update操作时间复杂度为O(1),query操作为O(N),引入前缀和数组之后,query的时间复杂度是变为O(1),但是update的时间复杂度又变为O(N)。
为了降低上述两操作的平均时间复杂度,引入线段树这种数据结构,使得update 和 query的时间复杂度都变为O(log(N))。
线段树的每个节点存储某一个段区间之和,其中每个结点的左子树和右子树分别存储当前结点的前半段之和和后半段之和,叶子结点存储的线段长度为1,根结点存储整个数组之和。
如下举例说明:
对于nums = [1, 2, 3, 4, 5, 6],线段树结构如下图所示:
由于我们发现其构成的线段树类似完全二叉树。因此可以使用像大/小根堆中的存储二叉树的方式存储该树。left = parent * 2 + 1,right = parent * 2 + 2。
一般使用4倍的原数组的大小存储该树。
对于当前结点,首先完成左孩子和右孩子的创建,之后其的值等于左右孩子值之和。
baseline为当前结点为叶子结点时,当前结点值即为nums元素值。
public class IntervalTree {
int[] nums;
int[] tree;
public IntervalTree(int[] nums) {
this.nums = nums;
this.tree = new int[nums.length * 4];
create(0, 0, nums.length - 1);
}
// left 和 right为nodeIndex对应的nums的线段区间
public void create(int nodeIndex, int left, int right) {
if(left == right) {
tree[nodeIndex] = nums[left];
return;
}
int mid = (left + right) / 2;
create(nodeIndex * 2 + 1, left, mid);
create(nodeIndex * 2 + 2, mid + 1, right);
tree[nodeIndex] = tree[nodeIndex * 2 + 1] + tree[nodeIndex * 2 + 2];
}
}
直观过程是修改当前位置对应数中的叶子结点,然后依次一层一层网上遍历,修改其父亲节点对应的值。
还是使用递归求解,代码与建树过程类似,不过需要注意的是不需要走完全树,只需走完对应的部分即可。
public void update(int index, int val) {
update(0, 0, nums.length - 1, index, val);
}
public void update(int nodeIndex, int left, int right, int index, int val) {
if(left == right) {
tree[nodeIndex] = nums[left] = val;
return;
}
int mid = (left + right) / 2;
if(index <= mid) {
update(nodeIndex * 2 + 1, left, mid, index, val);
}else {
update(nodeIndex * 2 + 2, mid + 1, right, index, val);
}
tree[nodeIndex] = tree[nodeIndex * 2 + 1] + tree[nodeIndex * 2 + 2];
}
代码类似建树过程,不过只需计算与当前区间有交集的部分。
public int query(int start, int end) {
return query(0, 0, nums.length - 1, start, end);
}
public int query(int nodeIndex, int left, int right, int start, int end) {
if(start > right || end < left) {
return 0;
}
if(left >= start && right <= end) {
return tree[nodeIndex];
}
int mid = (left + right) / 2;
return query(nodeIndex * 2 + 1, left, mid, start, end)
+ query(nodeIndex * 2 + 2, mid + 1, right, start, end);
}
完整代码如下:
public class IntervalTree {
int[] nums;
int[] tree;
public IntervalTree(int[] nums) {
this.nums = nums;
this.tree = new int[nums.length * 4];
create(0, 0, nums.length - 1);
}
public void create(int nodeIndex, int left, int right) {
if(left == right) {
tree[nodeIndex] = nums[left];
return;
}
int mid = (left + right) / 2;
create(nodeIndex * 2 + 1, left, mid);
create(nodeIndex * 2 + 2, mid + 1, right);
tree[nodeIndex] = tree[nodeIndex * 2 + 1] + tree[nodeIndex * 2 + 2];
}
public void update(int index, int val) {
update(0, 0, nums.length - 1, index, val);
}
public void update(int nodeIndex, int left, int right, int index, int val) {
if(left == right) {
tree[nodeIndex] = nums[left] = val;
return;
}
int mid = (left + right) / 2;
if(index <= mid) {
update(nodeIndex * 2 + 1, left, mid, index, val);
}else {
update(nodeIndex * 2 + 2, mid + 1, right, index, val);
}
tree[nodeIndex] = tree[nodeIndex * 2 + 1] + tree[nodeIndex * 2 + 2];
}
public int query(int start, int end) {
return query(0, 0, nums.length - 1, start, end);
}
public int query(int nodeIndex, int left, int right, int start, int end) {
if(start > right || end < left) {
return 0;
}
if(left >= start && right <= end) {
return tree[nodeIndex];
}
int mid = (left + right) / 2;
return query(nodeIndex * 2 + 1, left, mid, start, end)
+ query(nodeIndex * 2 + 2, mid + 1, right, start, end);
}
}