LeetCode Question: Range Sum Query - Mutable

Range Sum Query - Mutable

Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.

The update(i, val) function modifies nums by updating the element at index i to val.
Example:
Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
Note:
The array is only modifiable by the update function.
You may assume the number of calls to update and sumRange function is distributed evenly.

Analysis

An intuitive but less efficient solution to this problem is to use a simple loop:

  • loop from index i to j, accumulate the sum and return.

It is easy to see that the time complexity is O(n) for computing a range sum and O(1) for an update.
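A minimal sketch of this brute-force approach in Python 3, assuming the same `update`/`sumRange` interface as the problem. The class name `NaiveNumArray` is hypothetical and only serves as a baseline for the segment tree discussed next:

```python
from typing import List


class NaiveNumArray:
    """Brute-force baseline: O(1) update, O(n) sumRange (hypothetical name)."""

    def __init__(self, nums: List[int]) -> None:
        self.nums = list(nums)  # copy so the caller's list is not mutated

    def update(self, i: int, val: int) -> None:
        # Overwrite a single element: constant time.
        self.nums[i] = val

    def sumRange(self, i: int, j: int) -> int:
        # Loop from index i to j (inclusive) and accumulate: O(j - i + 1) time.
        total = 0
        for k in range(i, j + 1):
            total += self.nums[k]
        return total


arr = NaiveNumArray([1, 3, 5])
print(arr.sumRange(0, 2))  # 9
arr.update(1, 2)
print(arr.sumRange(0, 2))  # 8
```

With calls to `update` and `sumRange` distributed evenly, the O(n) query dominates, which is what the segment tree improves on.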

Another solution is to use a data structure called a segment tree. In this post, we briefly introduce this data structure and focus on how to implement it in an intuitive way.

Specifically, in our problem, a segment tree can be viewed as a binary tree where:

  • Leaf nodes are the elements of the input array.
  • Internal nodes store some merge of the leaf nodes beneath them. In this problem, each internal node stores the sum of the leaf nodes under it.

An example of a segment tree is shown in the figure below:

Now I will try to answer the following questions:

  • How does this help solve our problem?
  • How to construct the segment tree?
  • How to compute the range sum?
  • How to update the segment tree?

How does this help solve our problem?
From the tree structure shown above, we can see that each internal node (i.e. each non-leaf node) already stores the sum of a certain range, and these ranges come from repeatedly halving the array, just as in binary search. If we can combine these precomputed range sums to answer any range-sum query, it will be much more efficient than using a loop. So how do we compute the range sum? Don't worry, we will discuss this later.
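To make the idea concrete, here is a small Python 3 sketch that lists which node ranges [st, ed] would be added together to answer a query, using the same halving of ranges as the tree. It works on indices only and stores no sums; the helper name `covering_ranges` is hypothetical and used for illustration only:

```python
from typing import List, Tuple


def covering_ranges(st: int, ed: int, i: int, j: int) -> List[Tuple[int, int]]:
    """Return the node ranges whose stored sums add up to sum(nums[i..j])."""
    if i <= st and ed <= j:          # node range fully inside the query
        return [(st, ed)]
    if ed < i or st > j:             # node range disjoint from the query
        return []
    mid = st + (ed - st) // 2        # partial overlap: split the way the tree does
    return covering_ranges(st, mid, i, j) + covering_ranges(mid + 1, ed, i, j)


# An 8-element array (indices 0..7), query [1, 6].
print(covering_ranges(0, 7, 1, 6))  # [(1, 1), (2, 3), (4, 5), (6, 6)]
```

Here the query over six elements is answered with four precomputed sums, and the gap grows with the array size.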

How to construct the segment tree?
As we said, given an input array, we want to construct a tree structure where (1) leaf nodes are the elements of the input array, and (2) internal nodes store some merge of the leaf nodes beneath them; in this problem, each internal node stores the sum of the leaf nodes under it.

Splitting ranges in half, as in binary search, is a natural way to construct the tree. Specifically, for a given range of the input array we create a node whose value is the sum of its left and right children's values. The left and right children are themselves segment trees, built from the left half and the right half of that range, and a range containing a single element becomes a leaf holding that element. We can therefore build the whole segment tree recursively.
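A minimal Python 3 sketch of this recursive construction. The `Node` class and `build` function are hypothetical names chosen for this illustration; they mirror the `TreeNode`/`construct_seg_tree` idea in the full code later in the post:

```python
from typing import List, Optional


class Node:
    """A segment-tree node; `value` is the sum of nums[st..ed] for its range."""

    def __init__(self, value: int = 0) -> None:
        self.value = value
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build(nums: List[int], st: int, ed: int) -> Node:
    """Recursively build the tree for nums[st..ed] (both ends inclusive)."""
    node = Node()
    if st == ed:                     # base case: a leaf holds one array element
        node.value = nums[st]
        return node
    mid = st + (ed - st) // 2        # split the range in half
    node.left = build(nums, st, mid)
    node.right = build(nums, mid + 1, ed)
    node.value = node.left.value + node.right.value  # internal node = sum of children
    return node


root = build([1, 3, 5], 0, 2)
print(root.value, root.left.value, root.right.value)  # 9 4 5
```

For nums = [1, 3, 5] the root covers [0, 2], its left child covers [0, 1] (1 + 3 = 4), and its right child is the leaf for index 2.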

How to compute the range sum?
We have a segment tree, and the goal is to compute the range sum given the start and end indices. By definition, each node is associated with a range [st, ed], and the node's value is the sum of the elements in that range.
Now suppose we have a query for the range sum over [i, j]:

  1. If [st, ed] and [i, j] are identical, the node value is exactly what we want.
  2. If [st, ed] lies entirely inside [i, j], the current node's value is part of our sum, so we return it. It is not the whole answer, though: the rest of the query, [i, st-1] and [ed+1, j], is covered by other branches of the search.
  3. If [st, ed] lies entirely outside [i, j], the current range has nothing to do with our goal, so we ignore the current node (and all the nodes below it).
  4. If [st, ed] partially overlaps [i, j], we continue searching the left and right children of the current node.
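The four cases above map directly onto a short recursive function. This is a minimal Python 3 sketch; `Node`, `build`, and `query` are hypothetical names for this illustration, and cases 1 and 2 share one branch because an identical range is also "entirely inside":

```python
from typing import List, Optional


class Node:
    def __init__(self, value: int = 0) -> None:
        self.value = value
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build(nums: List[int], st: int, ed: int) -> Node:
    node = Node()
    if st == ed:
        node.value = nums[st]
    else:
        mid = st + (ed - st) // 2
        node.left = build(nums, st, mid)
        node.right = build(nums, mid + 1, ed)
        node.value = node.left.value + node.right.value
    return node


def query(node: Node, st: int, ed: int, i: int, j: int) -> int:
    """Sum of nums[i..j] using the subtree `node`, which covers [st, ed]."""
    if i <= st and ed <= j:       # cases 1 and 2: node range inside [i, j]
        return node.value
    if ed < i or st > j:          # case 3: no overlap, contributes nothing
        return 0
    mid = st + (ed - st) // 2     # case 4: partial overlap, search both children
    return query(node.left, st, mid, i, j) + query(node.right, mid + 1, ed, i, j)


nums = [1, 3, 5]
root = build(nums, 0, len(nums) - 1)
print(query(root, 0, 2, 0, 2))  # 9
print(query(root, 0, 2, 1, 2))  # 8
```

A leaf's range is always either inside or disjoint from [i, j], so the recursion never descends into a missing child.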

Having listed every possible relationship between the query range [i, j] and the range of a tree node, we can write the range-sum algorithm as a simple tree search (see the figure below).

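Computing a range sum now takes O(log n) time, where n is the number of elements in the input array. At each level of the tree, at most two nodes partially overlap [i, j] (the ones whose ranges straddle the boundary at i or at j), so the search expands at most two nodes per level, and the tree has O(log n) levels. Constructing the segment tree takes O(n) time, since the tree has 2n - 1 nodes and each is created once.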
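As an empirical sanity check of the O(log n) claim, the Python 3 sketch below counts how many nodes the range-sum search visits. The helper `visited_nodes`, the array size n = 1024, and the sampled query grid are all assumptions made for this illustration; since at most two nodes per level are expanded, at most four nodes per level are visited, giving a bound of 1 + 4 * 10 = 41 for a tree with 10 levels below the root:

```python
def visited_nodes(st: int, ed: int, i: int, j: int) -> int:
    """Count the nodes the range-sum search touches for query [i, j].

    Hypothetical instrumentation: mirrors the search without storing sums.
    """
    if (i <= st and ed <= j) or ed < i or st > j:
        return 1                       # search stops here: inside or disjoint
    mid = st + (ed - st) // 2
    return 1 + visited_nodes(st, mid, i, j) + visited_nodes(mid + 1, ed, i, j)


n = 1024  # log2(n) = 10 levels below the root
worst = max(visited_nodes(0, n - 1, i, j)
            for i in range(0, n, 7) for j in range(i, n, 13))
print(worst)  # bounded by a small multiple of log2(n), far below n
```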

How to update the segment tree?
Updating a leaf is straightforward: we find the path from the root to the leaf we want to update and add delta to every node on that path, where delta is the difference between the new value and the original value. Every node on this path stores the sum of a range that contains the updated element, so each of their values must change along with the leaf.
The time complexity of updating a leaf is also O(log n), since the path contains one node per level.
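A minimal Python 3 sketch of the delta-propagation update. `Node`, `build`, and `update` are hypothetical names; unlike the full code later in the post, which recomputes the difference from `nums[i]` at every level, this version computes delta once and passes it down:

```python
from typing import List, Optional


class Node:
    def __init__(self, value: int = 0) -> None:
        self.value = value
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build(nums: List[int], st: int, ed: int) -> Node:
    node = Node()
    if st == ed:
        node.value = nums[st]
    else:
        mid = st + (ed - st) // 2
        node.left = build(nums, st, mid)
        node.right = build(nums, mid + 1, ed)
        node.value = node.left.value + node.right.value
    return node


def update(node: Node, st: int, ed: int, i: int, delta: int) -> None:
    """Add `delta` to every node on the root-to-leaf path for index i."""
    node.value += delta              # this node's range [st, ed] contains i
    if st == ed:                     # reached the leaf for index i
        return
    mid = st + (ed - st) // 2
    if i <= mid:
        update(node.left, st, mid, i, delta)
    else:
        update(node.right, mid + 1, ed, i, delta)


nums = [1, 3, 5]
root = build(nums, 0, len(nums) - 1)
i, val = 1, 2
update(root, 0, len(nums) - 1, i, val - nums[i])  # delta = new value - old value
nums[i] = val
print(root.value)  # 8
```

Passing delta down once avoids an ordering subtlety: recomputing it from `nums[i]` at each level only works if `nums[i]` is overwritten after the whole recursion finishes.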

Code (C++):

```cpp
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
using namespace std;

class NumArray {
public:
    NumArray(vector<int> nums) {
        n = nums.size();
        if (n == 0) { return; }          // segTree stays nullptr for an empty input
        this->nums = nums;
        segTree = constructSegTree(nums, 0, n - 1);
    }

    ~NumArray() { destroySegTree(segTree); }
    NumArray(const NumArray&) = delete;             // the tree is owned by raw pointers
    NumArray& operator=(const NumArray&) = delete;

    void update(int i, int val) {
        updateSegTree(segTree, 0, n - 1, i, val);
        nums[i] = val;                   // after the tree update, which reads the old nums[i]
    }

    int sumRange(int i, int j) {
        return treeSum(segTree, 0, n - 1, i, j);
    }

private:
    struct TreeNode {
        int value = 0;
        TreeNode* left = nullptr;
        TreeNode* right = nullptr;
    };

    TreeNode* segTree = nullptr;  // segment tree root
    vector<int> nums;             // input array
    int n = 0;                    // length of input array

    // Build the tree for nums[st..ed]: leaves hold elements, internal nodes hold sums.
    TreeNode* constructSegTree(const vector<int>& nums, int st, int ed) {
        TreeNode* tnode = new TreeNode();
        if (st == ed) {
            tnode->value = nums[st];
        } else {
            int mid = st + (ed - st) / 2;
            tnode->left = constructSegTree(nums, st, mid);
            tnode->right = constructSegTree(nums, mid + 1, ed);
            tnode->value = tnode->left->value + tnode->right->value;
        }
        return tnode;
    }

    // Sum of nums[l..r] within the subtree that covers [st, ed].
    int treeSum(TreeNode* segTree, int st, int ed, int l, int r) {
        if (st >= l && ed <= r) {            // node range fully inside the query
            return segTree->value;
        } else if (ed < l || st > r) {       // no overlap
            return 0;
        } else {                             // partial overlap: search both children
            int mid = st + (ed - st) / 2;
            return treeSum(segTree->left, st, mid, l, r)
                 + treeSum(segTree->right, mid + 1, ed, l, r);
        }
    }

    // Add (val - nums[i]) to every node on the path from the root to leaf i.
    void updateSegTree(TreeNode* segTree, int st, int ed, int i, int val) {
        int diff = val - nums[i];
        segTree->value += diff;
        if (st == ed) { return; }
        int mid = st + (ed - st) / 2;
        if (i <= mid) {
            updateSegTree(segTree->left, st, mid, i, val);
        } else {
            updateSegTree(segTree->right, mid + 1, ed, i, val);
        }
    }

    // Free all nodes (post-order).
    void destroySegTree(TreeNode* node) {
        if (!node) { return; }
        destroySegTree(node->left);
        destroySegTree(node->right);
        delete node;
    }

    // Print segTree level by level (debug only).
    void checkSegTree(TreeNode* segTree) {
        if (!segTree) { return; }
        queue<TreeNode*> q1, q2;
        q1.push(segTree);
        while (!q1.empty()) {
            while (!q1.empty()) {
                TreeNode* tmp = q1.front();
                q1.pop();
                cout << tmp->value << " ";
                if (tmp->left) { q2.push(tmp->left); }
                if (tmp->right) { q2.push(tmp->right); }
            }
            cout << endl;
            swap(q1, q2);                    // next level becomes current; q2 is now empty
        }
    }
};

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray* obj = new NumArray(nums);
 * obj->update(i,val);
 * int param_2 = obj->sumRange(i,j);
 */
```

Code (Python):

```python
class TreeNode(object):
    def __init__(self, val=0):
        self.value = val
        self.left = None
        self.right = None


class NumArray(object):
    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        self.nums = nums
        self.n = len(nums)
        self.seg_tree = None  # stays None for an empty input
        if self.n == 0:
            return
        self.seg_tree = self.construct_seg_tree(0, self.n - 1)
        # self.print_tree()

    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: void
        """
        self.update_seg_tree(self.seg_tree, 0, self.n - 1, i, val)
        self.nums[i] = val  # after the tree update, which reads the old nums[i]

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        return self.tree_sum(self.seg_tree, 0, self.n - 1, i, j)

    def construct_seg_tree(self, st, ed):
        # Leaves hold single elements; internal nodes hold the sum of their children.
        tmp = TreeNode()
        if st == ed:
            tmp.value = self.nums[st]
        else:
            mid = st + (ed - st) // 2  # integer division; `/` gives a float in Python 3
            tmp.left = self.construct_seg_tree(st, mid)
            tmp.right = self.construct_seg_tree(mid + 1, ed)
            tmp.value = tmp.left.value + tmp.right.value
        return tmp

    def tree_sum(self, seg_tree, st, ed, i, j):
        if st >= i and ed <= j:        # node range fully inside [i, j]
            return seg_tree.value
        elif ed < i or st > j:         # no overlap
            return 0
        else:                          # partial overlap: search both children
            mid = st + (ed - st) // 2
            return (self.tree_sum(seg_tree.left, st, mid, i, j)
                    + self.tree_sum(seg_tree.right, mid + 1, ed, i, j))

    def update_seg_tree(self, seg_tree, st, ed, i, val):
        # Every node on the root-to-leaf path covers index i, so add the difference to each.
        diff = val - self.nums[i]
        seg_tree.value += diff
        if st == ed:
            return
        mid = st + (ed - st) // 2
        if i <= mid:
            self.update_seg_tree(seg_tree.left, st, mid, i, val)
        else:
            self.update_seg_tree(seg_tree.right, mid + 1, ed, i, val)

    def print_tree(self):
        # Print the tree level by level (debug only).
        q1 = [self.seg_tree] if self.seg_tree else []
        while q1:
            q2 = []
            for tmp in q1:
                print(tmp.value, end=", ")
                if tmp.left:
                    q2.append(tmp.left)
                if tmp.right:
                    q2.append(tmp.right)
            print()
            q1 = q2

# Your NumArray object will be instantiated and called as such:
# obj = NumArray(nums)
# obj.update(i,val)
# param_2 = obj.sumRange(i,j)
```
