下次做再写吧 注意hash的base不用太大不然容易挂
#include<cstdio>#include<iostream>#include<cstring>#include<cstdlib>#include<algorithm>#include<cmath>#include<map>using namespace std;typedef long long LL;const int _=1e2;const int maxn=1e5+_;const int W=97;const int P=1e9+7;int n,mw[maxn];void yu(){mw[0]=1;for(int i=1;i<maxn;i++)mw[i]=(LL)mw[i-1]*W%P;}int quick_pow(int A,int p){ int ret=1; while(p!=0) { if(p%2==1)ret=(LL)ret*A%P; A=(LL)A*A%P;p/=2; } return ret;} struct node{int x,y,next;};struct tree{ node a[2*maxn];int len,last[maxn],du[maxn]; void ins(int x,int y) { len++; a[len].x=x;a[len].y=y; a[len].next=last[x];last[x]=len; } void cop_main() { int x,y; for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); ins(x,y),ins(y,x); du[x]++,du[y]++; } } //~~~~~~~~~~~~~~maketree~~~~~~~~~~~~~~~~~~~~~ int tot[maxn],f[maxn],g[maxn]; int tp,v[maxn]; void getf1(int x,int fr) { tot[x]=1; bool isleaf=true; for(int k=last[x];k;k=a[k].next) if(a[k].y!=fr)isleaf=false,getf1(a[k].y,x),tot[x]+=tot[a[k].y]; if(isleaf){f[x]=1;return ;} tp=0; for(int k=last[x];k;k=a[k].next) if(a[k].y!=fr)v[++tp]=f[a[k].y]; sort(v+1,v+tp+1); for(int i=1;i<=tp;i++) f[x]=((LL)f[x]+(LL)v[i]*mw[i-1])%P; f[x]=(LL)f[x]*tot[x]%P; } pair<int,int>p[maxn];int pre[maxn],suf[maxn]; void getg(int x,int fr) { tp=0; if(x!=1)p[++tp]=make_pair(g[x],-1); for(int k=last[x];k;k=a[k].next) if(a[k].y!=fr)p[++tp]=make_pair(f[a[k].y],a[k].y); sort(p+1,p+tp+1); if(x!=1) { f[x]=0; for(int i=1;i<=tp;i++) f[x]=((LL)f[x]+(LL)p[i].first*mw[i-1])%P; f[x]=(LL)f[x]*n%P; } pre[0]=0;suf[tp+1]=0; for(int i=1;i<=tp;i++)pre[i]=((LL)pre[i-1]+(LL)p[i].first*mw[i-1])%P; for(int i=tp;i>=1;i--)suf[i]=((LL)suf[i+1]+(LL)p[i].first*mw[i-2])%P; for(int i=1;i<=tp;i++) if(p[i].second!=-1) { int y=p[i].second; g[y]=(LL)(n-tot[y])*(pre[i-1]+suf[i+1])%P; } if(x==1&&du[x]==1)g[p[1].second]=1; for(int k=last[x];k;k=a[k].next) if(a[k].y!=fr)getg(a[k].y,x); } void hash_main() { getf1(1,0); getg(1,0); } //~~~~~~~~~~~~~~~~~~~~hash~~~~~~~~~~~~~~~~~~~~~~~ }A,B;map<int,bool>mp;int main(){ freopen("a.in","r",stdin); freopen("a.out","w",stdout); scanf("%d",&n);yu(); A.cop_main(),A.hash_main(); n++,B.cop_main(),B.hash_main(); for(int i=1;i<n;i++)mp[A.f[i]]=true; for(int i=1;i<=n;i++) if(B.du[i]==1) { if(mp[(LL)B.f[i]*quick_pow(n,P-2)%P]==true) {printf("%dn",i);return 0;} } return 0;}