/**
 * A generic segment tree supporting inclusive range queries and point updates.
 *
 * <p>Values are combined with a caller-supplied {@link Merger}. The tree is stored in a flat
 * array where node {@code i} has children {@code 2i + 1} and {@code 2i + 2}.
 *
 * @param <E> element type
 */
public class SegmentTree<E> {
    /** Copy of the original elements, kept in sync by {@link #set(int, Object)}. */
    private final E[] data;
    /** Flattened tree nodes; sized with the conventional 4 * n bound. */
    private final E[] tree;
    /** Combines the results of two adjacent ranges. */
    private final Merger<E> merger;

    /**
     * Builds a segment tree over a copy of {@code arr}.
     *
     * @param arr    initial elements; copied, so later changes to {@code arr} are not seen
     * @param merger combines two adjacent range results
     * @throws NullPointerException if {@code arr} or {@code merger} is null
     */
    @SuppressWarnings("unchecked") // generic arrays cannot be created directly; both only ever hold E
    public SegmentTree(E[] arr, Merger<E> merger) {
        if (merger == null) {
            throw new NullPointerException("merger");
        }
        this.merger = merger;
        int n = arr.length;
        // Copy into a plain Object[] so later set() calls never hit an ArrayStoreException
        // when arr's runtime component type is narrower than E.
        data = (E[]) new Object[n];
        System.arraycopy(arr, 0, data, 0, n);
        tree = (E[]) new Object[4 * n];
        // An empty array has no root to build; query/set/get then reject every index.
        if (n > 0) {
            buildSegmentTree(0, 0, n - 1);
        }
    }

    /**
     * Returns the merged value of the elements in the inclusive range {@code [l, r]}.
     *
     * @param l left bound, inclusive
     * @param r right bound, inclusive
     * @return the merge of {@code data[l..r]}
     * @throws IllegalArgumentException if either bound is out of range or {@code l > r}
     */
    public E query(int l, int r) {
        if (l < 0 || l >= data.length || r < 0 || r >= data.length || l > r) {
            throw new IllegalArgumentException("index is illegal");
        }
        return query(0, 0, data.length - 1, l, r);
    }

    /**
     * Returns the merged value of {@code [queryL, queryR]} inside the subtree rooted at
     * {@code index}, which covers {@code [l, r]}. Callers guarantee
     * {@code l <= queryL <= queryR <= r}.
     */
    private E query(int index, int l, int r, int queryL, int queryR) {
        if (l == queryL && r == queryR) {
            return tree[index];
        }
        int mid = l + (r - l) / 2;
        int leftIdx = leftChild(index);
        int rightIdx = rightChild(index);
        if (queryL > mid) {
            // Query lies entirely in the right half.
            return query(rightIdx, mid + 1, r, queryL, queryR);
        }
        if (queryR <= mid) {
            // Query lies entirely in the left half.
            return query(leftIdx, l, mid, queryL, queryR);
        }
        // Query straddles the midpoint: answer each half and combine.
        E leftRes = query(leftIdx, l, mid, queryL, mid);
        E rightRes = query(rightIdx, mid + 1, r, mid + 1, queryR);
        return merger.merge(leftRes, rightRes);
    }

    /**
     * Replaces the element at {@code idx} with {@code e} and updates every affected node.
     *
     * @param idx position to update
     * @param e   new value
     * @throws IllegalArgumentException if {@code idx} is out of range
     */
    public void set(int idx, E e) {
        if (idx < 0 || idx >= data.length) {
            throw new IllegalArgumentException("index is illegal");
        }
        data[idx] = e;
        set(0, 0, data.length - 1, idx, e);
    }

    /**
     * In the subtree rooted at {@code treeIdx} (covering {@code [l, r]}), sets position
     * {@code index} to {@code e} and recomputes the ancestors on the way back up.
     */
    private void set(int treeIdx, int l, int r, int index, E e) {
        if (l == r) {
            tree[treeIdx] = e;
            // Leaf reached. Without this return the recursion keeps descending past the array.
            return;
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIdx);
        int rightTreeIndex = rightChild(treeIdx);
        if (index > mid) {
            set(rightTreeIndex, mid + 1, r, index, e);
        } else {
            set(leftTreeIndex, l, mid, index, e);
        }
        tree[treeIdx] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
    }

    /** Fills the subtree rooted at {@code treeIndex}, which covers {@code data[l..r]}. */
    private void buildSegmentTree(int treeIndex, int l, int r) {
        if (l == r) {
            tree[treeIndex] = data[l];
            return;
        }
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        int mid = l + (r - l) / 2;
        buildSegmentTree(leftTreeIndex, l, mid);
        buildSegmentTree(rightTreeIndex, mid + 1, r);
        tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
    }

    /** Returns the number of elements the tree was built over. */
    public int getSize() {
        return data.length;
    }

    /**
     * Returns the current element at {@code idx}.
     *
     * @throws IllegalArgumentException if {@code idx} is out of range
     */
    public E get(int idx) {
        if (idx < 0 || idx >= data.length) {
            throw new IllegalArgumentException("idx is illegal");
        }
        return data[idx];
    }

    /** Index of the left child of node {@code idx} in the flat tree array. */
    private static int leftChild(int idx) {
        return 2 * idx + 1;
    }

    /** Index of the right child of node {@code idx} in the flat tree array. */
    private static int rightChild(int idx) {
        return 2 * idx + 2;
    }

    /** Renders the raw tree array, with unused slots shown as {@code null}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        for (int i = 0; i < tree.length; i++) {
            // append(Object) already writes "null" for null references.
            sb.append(tree[i]);
            if (i != tree.length - 1) {
                sb.append(",");
            }
        }
        sb.append(']');
        return sb.toString();
    }
}
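// Main.java — test driver for SegmentTree (Java allows only one public top-level class per file, so place it in its own file).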
/** Small driver that builds a sum segment tree and prints a few range sums. */
public class Main {
    public static void main(String[] args) {
        Integer[] values = {-2, 0, 3, -5, 2, -1};
        SegmentTree<Integer> sumTree = new SegmentTree<>(values, (x, y) -> x + y);

        // Each pair is an inclusive [left, right] query range, printed in order.
        int[][] ranges = {{0, 2}, {2, 5}, {0, 5}};
        for (int[] range : ranges) {
            System.out.println(sumTree.query(range[0], range[1]));
        }
    }
}