线段树

我杀文化课

2019-10-05 21:33:10

Solution

更新:(之前图挂了,现在重新上传一下) ------------ 今天本蒟蒻终于学了线段树,我觉得有必要来讲解一下这个极好~~毒瘤~~的算法。 ## 那么,什么是线段树呢? 线段树(Interval Tree)。又名为区间树,每个结点代表的是一个区间,如图表示一个1~10区间的线段树结构。 ![](https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=2676049308,1082598863&fm=26&gp=0.jpg) 一般来说,它可以做: 1、查询某个区间 2、修改某个区间或点 是不是觉得可以干的事很少? ~~(其实我也觉得)~~ NO!它的用处很大,对于所有区间的问题,他几乎都可以解决!并且还挺快。 好了,我们先来康康这个题。 [洛谷P3372 线段树模板](https://www.luogu.org/problem/P3372) 首先我们分析一下:题目要求我们进行区间修改+区间查询,看看上面线段树的图,是不是觉得这个算法就是为这个题而生的?!! ~~不,其实是这个题为这个算法而生的~~ 那么第一个问题来了,如何建树?我们可以清楚地通过上图发现线段树的每个子节点都是由它的父亲节点**二分**而来的,再通过二叉树的性质 **我们可以知道对于一个父节点i,它的左右儿子节点分别为i乘以2,i乘以2+1** 所以我们可以根据这个性质用数组来模拟线段树,代码如下。 ```cpp void build(int id,int l,int r) { tree[id].l=l; tree[id].r=r;//l,r代表这个节点所代表的区间的左端点和右端点。 if(l==r) { tree[id].sum=a[l]; return; }//l==r时说明它就是最下层叶子节点(只代表一个数),那么我们直接就将a数组(原始数列)的第l(或r)个元素映射上去。 int mid=(l+r)/2; build(id*2,l,mid); build(id*2+1,mid+1,r);//二分建树 tree[id].sum=tree[id*2].sum+tree[id*2+1].sum;//sum表示这段区间的区间和,所以这个节点的sum值等于它的儿子节点的值之和。 } ``` 现在建树问题解决了,那区间更新怎么办呢? 现在我们思考一个问题,如果我们要对这颗树上的区间加上v,是不是只用找到代表这个区间的节点(可能不止一个节点,例如[2,3]这个区间就需要查找[2,2]和[3,3]这两个区间,然后将他们的sum加起来),然后**将这个节点的sum值加上v乘以这个区间所包含的数的个数**?下次访问相同的区间的时候直接将这个数返回过去,根本就不用再向下遍历了。 但是问题又来了,上一次我对[1,2]这个区间加了v,这一次我想知道[2,3]的区间和怎么办呢?要是只用我们上面提到的方法,很明显遍历到[2,2]的时候这个点并没有加v,那返回时答案就错了呀! 怎么办? 这时延迟标记(lazy)就派上用场了,它代表我对这个区间的数加过的值,我们在对一个区间加了v之后,我们将lazy的值也加上这个v,后来我们再访问这个区间但只是访问这个区间的一部分时我们就可以知道这个父节点下面的儿子节点需要加上多少了,但我们现在不用急着将这个lazy标记给它的儿子们加上,因为这很费时,当我们下次将要访问它的儿子节点时再给他们加上lazy标记,此时这个父节点的lazy值就要清零,我们就称其为标记下传。 举个例子:我们在访问[2,3]时,我们会访问[2,2]和[3,3]这两个节点,将要访问[2,2]时,我们肯定会访问[1,2]这个点,那么在向下遍历的同时,我们就顺便将lazy标记下传,下传时就把每个叶子节点的lazy加上这个值,再按之前说到的方法更新sum。如此反复。 代码如下: ```cpp void down(int id) { tree[id*2].lazy+=tree[id].lazy; tree[id*2].sum+=(tree[id*2].r-tree[id*2].l+1)*tree[id].lazy;//下传左儿子。 tree[id*2+1].lazy+=tree[id].lazy; tree[id*2+1].sum+=(tree[id*2+1].r-tree[id*2+1].l+1)*tree[id].lazy;//下传右儿子 tree[id].lazy=0;//将标记清零 } void update(int id,int l,int r,int v) { if(tree[id].l>r || tree[id].r<l) return; //如果你要访问的区间和你遍历到的这个区间完全不重合,那么返回,相当于剪枝。 if(tree[id].l>=l && tree[id].r<=r) { tree[id].lazy+=v; tree[id].sum+=(tree[id].r-tree[id].l+1)*v; return; }//如果你要访问的区间和当前遍历到的区间重合,那么就直接把lazy和sum更新,然后返回,因为没必要再向下了! if(tree[id].lazy>0) down(id);//之前遍历过,标记下传。 update(id*2,l,r,v); update(id*2+1,l,r,v);//递归更新左右儿子 tree[id].sum=tree[id*2].sum+tree[id*2+1].sum;//递归回来后将左右儿子的sum加起来赋给父亲的sum。 } ``` 总结一下,其实就是更新就分为这么两大步: 1.判断访问区间和遍历区间的关系(相离,部分相交,完全包含) 如果相离,直接返回。 如果部分相交,那么就需要将标记下传(如果有),然后向下找完全包含的区间。 如果完全包含,更新lazy标记,并且将sum更新(因为当前不会再向下,所以不能通过左右儿子的sum来计算)要按代码那样计算(也就是将这个节点的sum值加上v * 这个区间所包含的数的个数) 在这里,我提一个问题,为什么要在更新时下传标记,貌似完全可以在查询到这个节点时下传标记啊,粗略一想,貌似挺有道理,但是仔细思考后,你会发现代码中,你更新完左右儿子回来后,会给父节点的sum值进行更新,如果按照刚才的说法,那么之前你加过的那个值就漏了,所以在更新时一定要标记下传! 2.标记下传 先下传到左儿子 再下传到右儿子 下传时你需要更新儿子节点的lazy(因为它还可能有儿子,下次遍历它的儿子时要继续下传)和它的sum(与1中的sum更新方法相同) 最后,将自己的lazy标记清空(否则会重复下传,造成答案错误) ------------ 到了最后一步了,那就是查询,其实查询和更新很相似,也需要讨论所查询的区间和你现在遍历到的区间的关系(相离,部分相交,完全包含),查询到有lazy标记时一样要下传,就不多说了。 代码如下: ```cpp long long query(int id,int l,int r) { if(tree[id].l>r || tree[id].r<l) return 0;//相离 if(tree[id].l>=l && tree[id].r<=r) return tree[id].sum;//完全包含 if(tree[id].lazy>0) down(id); return query(id*2,l,r)+query(id*2+1,l,r);//部分包含 } ``` 这就是线段树最基础的啦! 完整代码: ```cpp #include<cstdio> using namespace std; int n,m,a[100010]; struct llk { long long l,r,lazy,sum; }tree[500010]; void build(int id,int l,int r) { tree[id].l=l; tree[id].r=r; if(l==r) { tree[id].sum=a[l]; return; } int mid=(l+r)/2; build(id*2,l,mid); build(id*2+1,mid+1,r); tree[id].sum=tree[id*2].sum+tree[id*2+1].sum; } void down(int id) { tree[id*2].lazy+=tree[id].lazy; tree[id*2].sum+=(tree[id*2].r-tree[id*2].l+1)*tree[id].lazy; tree[id*2+1].lazy+=tree[id].lazy; tree[id*2+1].sum+=(tree[id*2+1].r-tree[id*2+1].l+1)*tree[id].lazy; tree[id].lazy=0; } void update(int id,int l,int r,int v) { if(tree[id].l>r || tree[id].r<l) return; if(tree[id].l>=l && tree[id].r<=r) { tree[id].lazy+=v; tree[id].sum+=(tree[id].r-tree[id].l+1)*v; return; } if(tree[id].lazy>0) down(id); update(id*2,l,r,v); update(id*2+1,l,r,v); tree[id].sum=tree[id*2].sum+tree[id*2+1].sum; } long long query(int id,int l,int r) { if(tree[id].l>r || tree[id].r<l) return 0; if(tree[id].l>=l && tree[id].r<=r) return tree[id].sum; if(tree[id].lazy>0) down(id); return query(id*2,l,r)+query(id*2+1,l,r); } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&a[i]); build(1,1,n); for(int i=1;i<=m;i++) { int o,x,y; scanf("%d",&o); if(o==1) { int k; scanf("%d%d%d",&x,&y,&k); update(1,x,y,k); } else { scanf("%d%d",&x,&y); printf("%lld\n",query(1,x,y)); } } } ```