menu Shadowice!
#3340. 「NOI2020」命运
157 浏览 | 2020-08-27 | 阅读时间: 约 1 分钟 | 分类: ACM | 标签:
请注意,本文编写于 54 天前,最后修改于 42 天前,其中某些信息可能已经过时。

一句话题意

树上有若干条直路径,现在割掉一些边,求割断所有输入路径的方案数

dp方程

我们认为一条路径的开头是深度大的点,结尾是深度小的点

一开始想在第二维存线头长度,因为割掉了短线头长线头也一并割掉了

后来发现这样设计转移时下标会动,不方便优化,所以重新设计了下状态

$dp(u,i)$表示从子树u内部未被割断的点出发,结尾点最大深度为i时的方案数

$ans(u)$表示子树u内部所有路径均被割断的方案数

$$pre(u,i)=\sum_{j \leq i}dp(u,j)$$

转移方程

$$dp(u,i)=(ans(v)+\sum_{j}dp(v,j)) \times dp(u,i) \\ +(ans(v)+pre(v,i)) \times dp(u,i) \\ +(ans(u)+pre(u,i-1)) \times dp(v,i)$$

线段树合并优化

可以很明确的看到新数组第i位由dp(u,i)和dp(v,i)乘上各自的系数得到

使用线段树合并优化它的转移,在发现只来自u或者只来自v的子树时,通过观察方程我们发现这一段区间内系数不变,因此可以打乘法标记,所以实现一个区间乘区间求和的线段树就可以了

建议线段树合并的时候写可持久化版的,bug少

代码

#include<cstdio>
#include<algorithm>
using namespace std;const int N=5*1e5+10;
typedef long long ll;const ll mod=998244353;
ll po(ll a,ll p)
{
    ll r=1;
    for(;p;p>>=1,a=a*a%mod)(r*=((p&1)?a:1))%=mod;
    return r;
}
int n;int m;
int al[N];int v[N<<1];int x[N<<1];int ct;
void add(int u,int V)
{
  v[++ct]=V;x[ct]=al[u];al[u]=ct; 
}
int dep[N];int max_dep[N];
void predfs(int u,int f)
{
    for(int i=al[u];i;i=x[i])
    {
    if(v[i]==f)continue;
        dep[v[i]]=dep[u]+1;
        predfs(v[i],u);
    }
}
//ll dp[N][N];ll pre[N][N];
ll ans[N];
struct linetree
{
    ll sum[N*30];ll mrk[N*30];int s[N*30][2];int ct;
    int create_son(int& p)
    {
        if(!p)p=++ct,mrk[p]=1;return p;
    }
    int new_node()
    {
        mrk[++ct]=1;return ct;
    }
    void pushdown(int p)
    {
        int ls=s[p][0];if(ls)(sum[ls]*=mrk[p])%=mod,(mrk[ls]*=mrk[p])%=mod;
        int rs=s[p][1];if(rs)(sum[rs]*=mrk[p])%=mod,(mrk[rs]*=mrk[p])%=mod;
        mrk[p]=1;
    }
    void update(int p)
    {
        sum[p]=0;
        int& ls=s[p][0];if(ls)(sum[p]+=sum[ls])%=mod;
        int& rs=s[p][1];if(rs)(sum[p]+=sum[rs])%=mod;
    }
    void insert(int p,int l,int r,int pos,ll va)
    {
        if(r-l==1)
        {
            sum[p]=va;return;
        }
        int mid=(l+r)>>1;
        if(pos<mid)insert(create_son(s[p][0]),l,mid,pos,va);
        else insert(create_son(s[p][1]),mid,r,pos,va);
        update(p);
    }
/*    void output(int p,int l,int r,ll su)
    {
        if(r-l==1)
        {
            suf[++siz]=(su+sum[p])%mod;
            dp[siz]=sum[p];
            pos[siz]=l;
            return;
        }
        int mid=(l+r)>>1;
        pushdown(p);
        if(s[p][0])output(s[p][0],l,mid,su);
        if(s[p][1])output(s[p][1],mid,r,su); 
    }
*/
    // p1=u,p2=v
    int merge(int p1,int p2,int l,int r,ll preu,ll prev)
    {
        if(r-l==1)
        {
            int p=new_node();
            sum[p]=(sum[p1]*(preu+sum[p2])%mod+sum[p2]*prev)%mod;
            return p;
        }
        int mid=(l+r)>>1;pushdown(p1);pushdown(p2);
        int& ls1=s[p1][0];int& ls2=s[p2][0];
        int& rs1=s[p1][1];int& rs2=s[p2][1];    
        int p=new_node();
        if(ls1&&ls2)
        {
            s[p][0]=merge(ls1,ls2,l,mid,preu,prev);
            (preu+=sum[ls2])%=mod;
            (prev+=sum[ls1])%=mod;
        }
        else 
        {
            (preu+=sum[ls2])%=mod;
            (prev+=sum[ls1])%=mod;
            if(ls1)(mrk[ls1]*=preu)%=mod,(sum[ls1]*=preu)%=mod,s[p][0]=ls1;
            if(ls2)(mrk[ls2]*=prev)%=mod,(sum[ls2]*=prev)%=mod,s[p][0]=ls2;    
        }
        if(rs1&&rs2)
            s[p][1]=merge(rs1,rs2,mid,r,preu,prev);
        else 
        {
            if(rs1)(mrk[rs1]*=preu)%=mod,(sum[rs1]*=preu)%=mod,s[p][1]=rs1;
            if(rs2)(mrk[rs2]*=prev)%=mod,(sum[rs2]*=prev)%=mod,s[p][1]=rs2;
        }
        update(p);
        return p;
    }
    
    void print(int p,int l,int r,int id)
    {
        // printf("%d.[%d,%d]=%lld\n",id,l,r,sum[p]);
        if(r-l==1)
        {
            printf("[%d]=%lld,",l,sum[p]);
            return;
        }
        pushdown(p);
        int mid=(l+r)>>1;
        int& ls=s[p][0];int& rs=s[p][1];
        if(ls)print(ls,l,mid,id);
        if(rs)print(rs,mid,r,id);
        return;
    }
       void destroy(int p,int l,int r,int dr)
    {
    //    printf("destroy:[%d,%d],%d\n",l,r,dr);
        if(r-l==1)return;
        int mid=(l+r)>>1;pushdown(p);
        if(mid<dr)destroy(s[p][1],mid,r,dr);
        else s[p][1]=0,destroy(s[p][0],l,mid,dr);
        update(p);
    }    
}lt;int rt[N];


void dfs(int u,int f)
{
   /* if(max_dep[u]==0)ans[u]=1;
    else dp[u][max_dep[u]]=1;
    for(int d=1;d<=n;d++)(pre[u][d]=dp[u][d]+pre[u][d-1])%=mod;
    for(int i=al[u];i;i=x[i])
    {
        int ve=v[i];
        if(ve==f)continue;
        dfs(ve,u);  
        for(int d=1;d<dep[u];d++)
        {
            ll tmp=0;
            (tmp+=dp[u][d]*(pre[ve][n]+ans[ve]))%=mod;
            (tmp+=dp[u][d]*(pre[ve][d]+ans[ve]))%=mod;
            (tmp+=dp[ve][d]*(pre[u][d-1]+ans[u]))%=mod;
            dp[u][d]=tmp;
        }
        (ans[u]*=(ans[ve]+ans[ve]+pre[ve][n]))%=mod;
        for(int d=1;d<=n;d++)
            (pre[u][d]=dp[u][d]+pre[u][d-1])%=mod;
    }
    */
    if(max_dep[u]==0)ans[u]=1;
    else lt.insert(rt[u],1,n+1,max_dep[u],1);
    for(int i=al[u];i;i=x[i])
    {
        int ve=v[i];
        if(ve==f)continue;
        dfs(ve,u);
        ll preu=lt.sum[rt[ve]];
        (preu+=ans[ve]*2%mod)%=mod;
        ll prev=ans[u];
        rt[u]=lt.merge(rt[u],rt[ve],1,n+1,preu,prev);
        (ans[u]*=preu)%=mod;
        lt.destroy(rt[u],1,n+1,dep[u]);
    }
/*    printf("dp[%d]:",u);
    lt.print(rt[u],1,n+1,u);
    printf("\n");
*/
          return;
}
int main()
{
    freopen("destiny.in","r",stdin);
    freopen("destiny.out","w",stdout);
    scanf("%d",&n);
    for(int i=1,u,V;i<n;i++)
        scanf("%d%d",&u,&V),add(u,V),add(V,u);
    dep[1]=1;predfs(1,0);
    for(int i=1;i<=n;i++)max_dep[i]=0;
    scanf("%d",&m);
    for(int i=1,u,V;i<=m;i++)
        scanf("%d%d",&u,&V),max_dep[V]=max(max_dep[V],dep[u]);
  // for(int i=1;i<=n;i++)printf("%d ",dep[i]);printf("\n");
  // for(int i=1;i<=n;i++)printf("%d ",max_dep[i]);printf("\n");
   for(int i=1;i<=n;i++)rt[i]=lt.new_node();
       dfs(1,0);
    // for(int i=1;i<=n;i++)
    // {
     //    for(int j=1;j<=n;j++)
     //        printf("%lld ",dp[i][j]);printf("\n");
    // }
  // for(int i=1;i<=n;i++)printf("%lld ",ans[i]);
    printf("%lld",ans[1]);
    return 0;
}
知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议

发表评论

email
web

全部评论 (共 1 条评论)

    2020-09-23 18:13
    前排打call