0

有没有办法使用分段树结构来计算数组中给定值的频率?

假设有一个大小为 N 的数组 A,并且数组的每个元素 A[i] 都包含值 0、1 或 2。我想执行以下操作:

  • 计算数组的任何范围 [a,b] 中的零数量
  • 递增(mod 3)数组的任何范围 [a,b] 中的每个元素

示例:如果 A = [0,1,0,2,0]:

  • Query[2,4] 必须返回 1 ,因为 [2,4] 范围内有一个 0
  • Increment[2,4] 将 A 更新为 [0,2,1,0,0]

这看起来与 Range Sum Query 问题非常相似,可以使用 Segment Trees 来解决(在这种情况下,由于范围更新,使用 Lazy Propagation),但是我没有成功地将我的 seg 树代码调整到这个问题,因为如果我存储树中的值就像在正常的 RSQ 中一样,任何包含值“3”(例如)的父节点都没有任何意义,因为有了这些信息,我无法提取该范围内存在多少零。

提前致谢!

--

编辑:

段树是在其节点中存储与数组相关的区间的二叉树结构。叶节点存储实际的数组单元,每个父节点存储其子节点的函数 f(node->left, node->right)。段树通常用于执行范围求和查询,其中我们想要计算数组范围 [a,b] 中所有元素的总和。在这种情况下,父节点计算的函数是其子节点中值的总和。我们想使用 segtrees 来解决 Range Sum Query 问题,因为它允许在 O(log n) 内解决它(我们只需要下降树,直到找到完全被我们的范围查询覆盖的节点),比朴素的 O(n) 算法。

4

2 回答 2

1

由于实际的数组值存储在叶子(L 级)中,让 L - 1 级的节点存储它们包含多少个零(这将是 [0, 2] 范围内的值)。除此之外,一切都是一样的,其余节点将计算 f(node->left, node->right) 作为node->left + node->right并且零的计数将传播到根。

增加范围后,如果该范围不包含零,则无需执行任何操作。但是,如果该范围有零,那么所有这些零现在都将是 1,并且当前节点的函数值(称为 F)现在变为零。现在需要将值的变化向上传播到根,每次从函数值中减去 F。

于 2015-12-30T09:45:50.183 回答
0

使用平方根分解可以很容易地解决这个问题首先创建新的前缀和数组,将每个前缀和除以 3。将整个数组划分为 sqrt(n) 个块。每个块将有 0、1 和 2 的数量。还要创建一个临时数组,其中包含要添加到块元素中的总和这是 c++ 中的实现:

#include <bits/stdc++.h>
using namespace std;
#define si(a) scanf("%d",&a)
#define sll(a) scanf("%lld",&a)
#define sl(a) scanf("%ld",&a)
#define pi(a) printf("%d\n",a)
#define pl(a) printf("%ld\n",a)
#define pll(a) printf("%lld\n",a) 
#define sc(a) scanf("%c",&a)
#define pc(a) printf("%c",a)
#define ll long long
#define mod 1000000007
#define w while
#define pb push_back
#define mp make_pair
#define f first
#define s second
#define INF INT_MAX
#define fr(i,a,b) for(int i=a;i<=b;i++)



///////////////////////////////////////////////////////////////
struct block
{
    int one;
    int two;
    int zero;
    block()
    {
        one=two=zero=0;
    }
};
ll a[100005],a1[100005];
ll sum[400];
int main()
{
    int n,m;
    cin>>n>>m;
    string s;
    cin>>s;
    int N=(int)(sqrt(n));
    struct block b[N+10];
    for(int i=0;i<n;i++)
    {
        a[i]=s[i]-'0';
        a[i]%=3;
        a1[i]=a[i];
    }
    for(int i=1;i<n;i++)
    a[i]=(a[i]+a[i-1])%3;
    for(int i=0;i<n;i++)
    {
        if(a[i]==0)
        b[i/N].zero++;
        else if(a[i]==1)
        b[i/N].one++;
        else
        b[i/N].two++;
    }
    w(m--)
    {
        int type;
        si(type);
        if(type==1)
        {
            int ind,x;
            si(ind);
            si(x);
            x%=3;
            ind--;
                int diff=(x-a1[ind]+3)%3;
                if(diff==1)
                {
                    int st=ind/N;
                    int end=(n-1)/N;
                    int kl=(st+1)*N;
                    int hj=min(n,kl);
                    for(int i=st*N;i<hj;i++)
                    {
                        a[i]=(a[i]+sum[st])%3;
                    }
                    sum[st]=0;
                    for(int i=ind;i<hj;i++)
                    {
                        if(a[i]==0)
                        b[st].zero--;
                        else if(a[i]==1)
                        b[st].one--;
                        else
                        b[st].two--;


                        a[i]=(a[i]+diff)%3;



                        if(a[i]==0)
                        b[st].zero++;
                        else if(a[i]==1)
                        b[st].one++;
                        else
                        b[st].two++;
                    }

                    for(int i=st+1;i<=end;i++)
                    {
                        int yu=b[i].zero;
                        b[i].zero=b[i].two;
                        b[i].two=b[i].one;
                        b[i].one=yu;
                        sum[i]=(sum[i]+diff)%3;
                    }
                }
                else if(diff==2)
                {


                    int st=ind/N;
                    int end=(n-1)/N;
                    int kl=(st+1)*N;
                    int hj=min(n,kl);
                    for(int i=st*N;i<hj;i++)
                    {
                        a[i]=(a[i]+sum[st])%3;
                    }
                    sum[st]=0;
                    for(int i=ind;i<hj;i++)
                    {
                        if(a[i]==0)
                        b[st].zero--;
                        else if(a[i]==1)
                        b[st].one--;
                        else
                        b[st].two--;


                        a[i]=(a[i]+diff)%3;



                        if(a[i]==0)
                        b[st].zero++;
                        else if(a[i]==1)
                        b[st].one++;
                        else
                        b[st].two++;
                    }

                    for(int i=st+1;i<=end;i++)
                    {
                        int yu=b[i].zero;
                        b[i].zero=b[i].one;
                        b[i].one=b[i].two;
                        b[i].two=yu;
                        sum[i]=(sum[i]+diff)%3;
                    }
                }

            a1[ind]=x%3;
        }
        else
        {
            int l,r;
            ll x=0,y=0,z=0;
            si(l);
            si(r);
            l--;
            r--;
            int st=l/N;
            int end=r/N;
            if(st==end)
            {
                for(int i=l;i<=r;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
            }
            else
            {
                for(int i=l;i<(st+1)*N;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
                for(int i=end*N;i<=r;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
                for(int i=st+1;i<=end-1;i++)
                {
                    x+=b[i].zero;
                    y+=b[i].one;
                    z+=b[i].two;
                }
            }
            ll temp=0;
            if(l!=0)
            {
                temp=(a[l-1]+sum[(l-1)/N])%3;
            }
            ll ans=(x*(x-1))/2;
            ans+=((y*(y-1))/2);
            ans+=((z*(z-1))/2);
            if(temp==0)
            ans+=x;
            else if(temp==1)
            ans+=y;
            else
            ans+=z;
            pll(ans);
        }
    }
    return 0;
}
于 2020-04-12T11:32:34.730 回答