鞠古香 发表于 2025-6-1 00:00:14

虚树

0.前言

若干年前会过一次,然后就不会了。
现在又会了。
1.虚树

存在这样一类问题,多次询问,每次给你 \(k\) 个特殊/关键点,要求有关这 \(k\) 个点的某种信息,并且题目保证了或隐含了 \(\sum k \le lim\),那么可以考虑使用虚树解决。
虚树,即将这 \(k\) 个点提出,并按照原树上的节点关系,建立一棵大小约为 \(2k\) 的树,可以在这棵树上求答案。
这里仅介绍二次排序 + LCA 的建树方法。

[*]将关键点按照 dfs 序排序;
[*]将排序后相邻两点的 LCA 加入,再按 dfs 序排序并去重;
[*]连 \((LCA(ve_{i - 1},ve_i),ve_i)\) 的边。
对于 2:一般会将点 \(1\) 也加入,便于求答案。
因为按照 dfs 序排序了,所以相邻两点之间一定不会出现 dfs 序介于他们俩到 LCA 的路径上的点,正确性得证。
建树时间复杂度:\(\mathcal{O(\sum k \log n)}\)。
2.题目

2.1.P2495 消耗战

板子题。
我们将虚树建出,接下来考虑树形 DP。
\(f_u\) 表示点 \(u\) 与其子树内所有关键点断开所需最小代价。
对于虚树上 \((u,v)\):若 \(v\) 是关键点,那么这条边必须断;否则判断是断这条边代价更小还是在 \(v\) 的子树中断若干边代价更小。
显然答案为 \(f_1\)。
#include <bits/stdc++.h>

#define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define vi vector<int>
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define me(a,x) memset(a,x,sizeof a)
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------")

using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
    template<class T> il void read(T &x) {
      x = 0; T f = 1; char ch = getchar();
      while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
      while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
      x *= f;
    }
    template<class T,class... Args> il void read(T &x,Args &...x_) { read(x); read(x_...); }
    template<class T> il void print(T x) {
      if (x < 0) ptc('-'), x = -x;
      if (x > 9) print(x / 10); ptc(x % 10 + '0');
    }
    template<class T,class T_> il void write(T x,T_ ch) { print(x); ptc(ch); }
    template<class T,class T_> il void chmax(T &x,T_ y) { x = x < (T)y ? (T)y : x; }
    template<class T,class T_> il void chmin(T &x,T_ y) { x = x > (T)y ? (T)y : x; }
    template<class T,class T_,class T__> il T qmi(T a,T_ b,T__ p) {
      T res = 1; while (b) {
            if (b & 1) res = res * a % p;
            a = a * a % p; b >>= 1;
      } return res;
    }
    template<class T> il T gcd(T a,T b) { if (!b) return a; return gcd(b,a % b); }
    template<class T,class T_> il void exgcd(T a, T b, T_ &x, T_ &y) {
      if (b == 0) { x = 1; y = 0; return; }
      exgcd(b,a % b,y,x); y -= a / b * x; return ;
    }
    template<class T,class T_> il T getinv(T x,T_ p) {
      T inv,y; exgcd(x,(T)p,inv,y);
      inv = (inv + p) % p; return inv;
    }
} using namespace szhqwq;
const int N = 5e5 + 10,inf = 1e9,mod = 998244353;
const ull base = 131,base_ = 233;
const ll inff = 1e18;
const db eps = 1e-6;
int n,m; vector<PII> G;
int h,e,ne,w,idx,dfn,tot;
int d,fa,dis,f;
bool st;

il void add(int a,int b,int c) {
    e = b;
    w = c;
    ne = h;
    h = idx ++;
    return ;
}

il void dfs(int u,int faa,int val) {
    dfn = ++ tot;
    fa = faa; dis = val;
    d = d + 1;
    rep(i,1,18)
      fa = fa],
      dis = min(dis,dis]);
    for (int i = h; ~i; i = ne) {
      int j = e;
      if (j == faa) continue;
      dfs(j,u,w);
    }
    return ;
}

il int LCA(int x,int y) {
    if (d < d) swap(x,y);
    rep1(i,18,0) if (d] >= d) x = fa;
    if (x == y) return x;
    rep1(i,18,0) if (fa != fa) x = fa,y = fa;
    return fa;
}

il int calc(int x,int p) {
    int ret = inf;
    rep1(i,18,0) if (d] >= d)
      chmin(ret,dis),x = fa;
    return ret;
}

il void dfss(int u,int faa) {
    f = 0;
    for (auto x : G) {
      int v = x.fst,w = x.snd;
      if (v == faa) continue;
      dfss(v,u);
      if (st) f += w;
      else f += min(f,w);
    }
    return ;
}

il void solve() {
    //------------code------------
    me(h,-1); me(dis,0x3f);
    read(n);
    rep(i,1,n - 1) {
      int a,b,c; read(a,b,c);
      add(a,b,c); add(b,a,c);
    }
    dfs(1,0,inf);
    read(m);
    while (m -- ) {
      int k; read(k);
      vi v,ve;
      rep(i,1,k) {
            int x; read(x);
            v.pb(x);
      }
      sort(all(v),[](int x,int y){ return dfn < dfn; });
      ve.pb(1); ve.pb(v);
      rep(i,1,k - 1)
            ve.pb(LCA(v,v)),
            ve.pb(v);
      sort(all(ve),[](int x,int y){ return dfn < dfn; });
      ve.erase(unique(all(ve)),ve.end());
      rep(i,0,sz(ve) - 1) st] = 0,G].clear();
      for (auto x : v) st = 1;
      rep(i,1,sz(ve) - 1) {
            int lca = LCA(ve,ve);
            G.pb(ve,calc(ve,lca));
      }
      dfss(1,0); write(f,'\n');
    }
    return ;
}

il void init() {
    return ;
}

signed main() {
    // init();
    int _ = 1;
    // read(_);
    while (_ -- ) solve();
    return 0;
}2.2.CF613D Kingdom and its Cities

同样建出虚树。要使所有关键点两两不连通,因为点权相同,可以直接贪心。
先判掉无解的情况。
若当前点 \(u\) 是关键点且其子树中存在其他关键点,那么子树中有多少关键点答案就加多少;
否则看子树中有多少关键点,若有 \(> 1\) 个,那么占领 \(u\) 点一定不劣;如果仅有 \(1\) 个,考虑放到父亲节点及上面进行处理一定更优。
#include <bits/stdc++.h>

// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define vi vector<int>
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define me(a,x) memset(a,x,sizeof a)
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------")

using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
    template<class T> il void read(T &x) {
      x = 0; T f = 1; char ch = getchar();
      while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
      while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
      x *= f;
    }
    template<class T,class... Args> il void read(T &x,Args &...x_) { read(x); read(x_...); }
    template<class T> il void print(T x) {
      if (x < 0) ptc('-'), x = -x;
      if (x > 9) print(x / 10); ptc(x % 10 + '0');
    }
    template<class T,class T_> il void write(T x,T_ ch) { print(x); ptc(ch); }
    template<class T,class T_> il void chmax(T &x,T_ y) { x = x < (T)y ? (T)y : x; }
    template<class T,class T_> il void chmin(T &x,T_ y) { x = x > (T)y ? (T)y : x; }
    template<class T,class T_,class T__> il T qmi(T a,T_ b,T__ p) {
      T res = 1; while (b) {
            if (b & 1) res = res * a % p;
            a = a * a % p; b >>= 1;
      } return res;
    }
    template<class T> il T gcd(T a,T b) { if (!b) return a; return gcd(b,a % b); }
    template<class T,class T_> il void exgcd(T a, T b, T_ &x, T_ &y) {
      if (b == 0) { x = 1; y = 0; return; }
      exgcd(b,a % b,y,x); y -= a / b * x; return ;
    }
    template<class T,class T_> il T getinv(T x,T_ p) {
      T inv,y; exgcd(x,(T)p,inv,y);
      inv = (inv + p) % p; return inv;
    }
} using namespace szhqwq;
const int N = 2e5 + 10,inf = 1e9,mod = 998244353;
const ull base = 131,base_ = 233;
const ll inff = 1e18;
const db eps = 1e-6;
int n,q,id,cnt,fa;
int h,e,ne,idx,d;
bool vis;
vi G;

il void add(int a,int b) {
    e = b;
    ne = h;
    h = idx ++;
    return ;
}

il void dfs(int u,int f) {
    fa = f;
    d = d + 1;
    rep(i,1,18) fa = fa];
    id = ++ cnt;
    for (int i = h; ~i; i = ne) {
      int j = e;
      if (j == f) continue;
      dfs(j,u);
    }
    return ;
}

il int LCA(int x,int y) {
    if (d < d) swap(x,y);
    rep1(i,18,0) if (d] >= d) x = fa;
    if (x == y) return x;
    rep1(i,18,0) if (fa != fa) x = fa,y = fa;
    return fa;
}

int ret = 0,st;
il void calcans(int u) {
    if (vis) st = 1;
    int tot = 0;
    for (auto v : G) {
      calcans(v);
      tot += st;
    }
    if (vis && tot) ret += tot;
    else if (tot == 1) st = 1;
    else if (tot > 1) ++ ret;
    return ;
}

il void solve() {
    //------------code------------
    read(n); me(h,-1);
    rep(i,1,n - 1) {
      int a,b; read(a,b);
      add(a,b); add(b,a);
    }
    dfs(1,0);
    read(q);
    while (q -- ) {
      int k; read(k);
      vi v,ve;
      rep(i,1,k) {
            int x; read(x);
            v.pb(x);
      }
      sort(all(v),[](int x,int y){ return id < id; });
      ve.pb(1); ve.pb(v);
      rep(i,1,sz(v) - 1)
            ve.pb(LCA(v,v)),
            ve.pb(v);
      sort(all(ve),[](int x,int y){ return id < id; });
      ve.erase(unique(all(ve)),ve.end());
      for (auto x : ve) G.clear(),vis = vis] = 0,st = 0;
      for (auto x : v) vis = 1;
      bool fl = 1;
      for (auto x : v) if (vis]) { fl = 0; break; }
      if (!fl) { puts("-1"); continue; }
      rep(i,1,sz(ve) - 1) G,ve)].pb(ve);
      ret = 0;
      calcans(1);
      write(ret,'\n');
    }
    return ;
}

il void init() {
    return ;
}

signed main() {
    // init();
    int _ = 1;
    // read(_);
    while (_ -- ) solve();
    return 0;
}2.3.P3233 世界树

建虚树,考虑 up and down DP。
\(f_u = (dist,id)\),分别表示最短距离及编号。
分别向上向下更新信息。
考虑最后计算答案怎么做。
对于虚树上 \((u,v)\),令 \(p\) 为 \(u\) 在原树上和 \(v\) 那条链上的儿子。
如果两点所属的关键点相同,则 \(cnt_{id} \gets siz_p - siz_v\)。
若不同,则倍增跳到分界点,即上半部分所属点与 \(u\) 相同,下半部分所属点与 \(v\) 相同的点 \(mid\),\(mid\) 自己和 \(y\) 相同。
\(cnt_{id_u} \gets siz_p - siz_{mid},cnt_{id_v} \gets siz_{mid} - siz_v\)。
注意到 \(u\) 某些子树中可能不存在关键点,那么这些子树中的点显然所属与 \(u\) 一致,处理一下即可。
#include <bits/stdc++.h>

// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define vi vector<int>
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define me(a,x) memset(a,x,sizeof a)
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------")

using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
    template<class T> il void read(T &x) {
      x = 0; T f = 1; char ch = getchar();
      while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
      while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
      x *= f;
    }
    template<class T,class... Args> il void read(T &x,Args &...x_) { read(x); read(x_...); }
    template<class T> il void print(T x) {
      if (x < 0) ptc('-'), x = -x;
      if (x > 9) print(x / 10); ptc(x % 10 + '0');
    }
    template<class T,class T_> il void write(T x,T_ ch) { print(x); ptc(ch); }
    template<class T,class T_> il void chmax(T &x,T_ y) { x = x < (T)y ? (T)y : x; }
    template<class T,class T_> il void chmin(T &x,T_ y) { x = x > (T)y ? (T)y : x; }
    template<class T,class T_,class T__> il T qmi(T a,T_ b,T__ p) {
      T res = 1; while (b) {
            if (b & 1) res = res * a % p;
            a = a * a % p; b >>= 1;
      } return res;
    }
    template<class T> il T gcd(T a,T b) { if (!b) return a; return gcd(b,a % b); }
    template<class T,class T_> il void exgcd(T a, T b, T_ &x, T_ &y) {
      if (b == 0) { x = 1; y = 0; return; }
      exgcd(b,a % b,y,x); y -= a / b * x; return ;
    }
    template<class T,class T_> il T getinv(T x,T_ p) {
      T inv,y; exgcd(x,(T)p,inv,y);
      inv = (inv + p) % p; return inv;
    }
} using namespace szhqwq;
const int N = 3e5 + 10,inf = 1e9,mod = 998244353;
const ull base = 131,base_ = 233;
const ll inff = 1e18;
const db eps = 1e-6;
int n,q,d,fa,siz;
int h,e,ne,idx;
PII f; int id,cnt;
vi G; bool vis;
int ret;

il void add(int a,int b) {
    e = b;
    ne = h;
    h = idx ++;
    return ;
}

il void dfs(int u,int f) {
    d = d + 1;
    fa = f; id = ++ cnt;
    rep(i,1,18) fa = fa];
    siz = 1;
    for (int i = h; ~i; i = ne) {
      int j = e;
      if (j == f) continue;
      dfs(j,u);
      siz += siz;
    }
    return ;
}

il int LCA(int x,int y) {
    if (d < d) swap(x,y);
    rep1(i,18,0) if (d] >= d) x = fa;
    if (x == y) return x;
    rep1(i,18,0) if (fa != fa) x = fa,y = fa;
    return fa;
}

il int getdis(int x,int y) {
    return d + d - 2 * d;
}

il void dfs1(int u) {
    f = {inf,0};
    if (vis) f = {0,u};
    for (auto v : G) {
      dfs1(v);
      if (f.snd) {
            int val = getdis(u,f.snd);
            if (val < f.fst) f = {val,f.snd};
            else if (val == f.fst) chmin(f.snd,f.snd);
      }
    }
    return ;
}

il void dfs2(int u) {
    for (auto v : G) {
      if (f.snd) {
            int val = getdis(v,f.snd);
            if (val < f.fst) f = {val,f.snd};
            else if (val == f.fst) chmin(f.snd,f.snd);
      }
      dfs2(v);
    }
    return ;
}

il void calcans(int u) {
    int val = siz;
    for (auto v : G) {
      int p = v;
      rep1(i,18,0) if (d] > d) p = fa;
      val -= siz;
      if (f.snd == f.snd) ret.snd] += siz - siz;
      else {
            int mid = v;
            rep1(i,18,0) {
                int dis = getdis(f.snd,fa),diss = getdis(f.snd,fa);
                if (dis > diss || dis == diss && f.snd < f.snd) mid = fa;
            }
            ret.snd] += siz - siz; ret.snd] += siz - siz;
      }
      calcans(v);
    }
    ret.snd] += val;
    return ;
}

il void solve() {
    //------------code------------
    read(n); me(h,-1);
    rep(i,1,n - 1) {
      int a,b; read(a,b);
      add(a,b); add(b,a);
    }
    dfs(1,0);
    read(q);
    while (q -- ) {
      int m; read(m);
      vi v,ve,vec;
      rep(i,1,m) {
            int x; read(x);
            v.pb(x); vec.pb(x);
            ret = 0;
      }
      sort(all(v),[](int x,int y){ return id < id; });
      ve.pb(1); ve.pb(v);
      rep(i,1,sz(v) - 1)
            ve.pb(LCA(v,v)),
            ve.pb(v);
      sort(all(ve),[](int x,int y){ return id < id; });
      ve.erase(unique(all(ve)),ve.end());
      for (auto x : ve) G.clear(),vis = 0;
      for (auto x : v) vis = 1;
      rep(i,1,sz(ve) - 1) G,ve)].pb(ve);
      dfs1(1); dfs2(1); calcans(1);
      for (auto x : vec) write(ret,' ');
      puts("");
    }
    return ;
}

il void init() {
    return ;
}

signed main() {
    // init();
    int _ = 1;
    // read(_);
    while (_ -- ) solve();
    return 0;
}
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
页: [1]
查看完整版本: 虚树