1758. 树上莫队1

莫队

复杂度 $O((n+m)\sqrt{n})$（$m$ 为询问数；$m$ 与 $n$ 同阶时即 $O(n\sqrt{n})$）

// Mo's algorithm over the DFS order (树上莫队): for each queried vertex x,
// count the distinct colours in the subtree of x.
//
// A preorder DFS gives vertex u the timestamp dfn[u]; the whole subtree of u
// occupies the contiguous timestamp range [dfn[u], mx[u]]. id[t] is the colour
// of the vertex visited at time t, so a subtree query becomes "number of
// distinct values in id[l..r]", answered offline by Mo's algorithm in
// O((n + m) * sqrt(n)).
//
// Input (stdin): n; n-1 edges "u v"; n colours a[1..n]; m; m query vertices.
// Output: one answer per line, in query order.
#include <algorithm>
#include <cmath>
#include <cstdio>

#define sc(x) scanf("%lld", &x)
typedef long long ll;
// Capacity for n and m. Colours index bin[] directly, so they must also be
// below this bound. NOTE(review): confirm against the problem constraints.
constexpr int mn = 100010;

ll n, m, sq, ans[mn], bin[mn], sum, a[mn], hd[mn], cnt;
ll dfn[mn], st, mx[mn], id[mn];

// One offline query: DFS-order range [l, r] and its original position i.
struct query {
    ll l, r, i;
    // Mo's order: sort by the block of l. Inside a block, r is ascending for
    // odd blocks and descending for even blocks, so the right pointer sweeps
    // back and forth instead of jumping back to the start (odd-even trick).
    bool operator<(const query &x) const {
        if (l / sq != x.l / sq) return l < x.l;
        return ((l / sq) & 1) ? r < x.r : r > x.r;
    }
} q[mn];

// Forward-star adjacency list; each undirected edge is stored as two arcs.
struct edge {
    ll to, nx;
} e[mn * 2];

static void adde(ll u, ll v) {
    e[++cnt] = {v, hd[u]};
    hd[u] = cnt;
}

// Preorder DFS: assigns dfn[u], records the colour seen at each timestamp,
// and sets mx[u] to the last timestamp inside u's subtree.
static void dfs(ll u, ll fa) {
    dfn[u] = ++st;
    id[st] = a[u];
    for (ll i = hd[u]; i; i = e[i].nx) {
        if (e[i].to != fa) dfs(e[i].to, u);
    }
    // Timestamps are handed out consecutively, so u's subtree ends at the
    // most recently assigned one.
    mx[u] = st;
}

// Move one timestamp into / out of the Mo window, keeping `sum` equal to the
// number of colours whose count in the window is non-zero.
static void add(ll p) { sum += bin[id[p]]++ == 0; }
static void del(ll p) { sum -= --bin[id[p]] == 0; }

signed main() {
    sc(n);
    // Block size ~ sqrt(n); clamped to 1 so the comparator never divides by 0.
    sq = std::max<ll>(1, static_cast<ll>(std::sqrt(static_cast<double>(n))));
    for (ll i = 1, u, v; i < n; ++i) {
        sc(u), sc(v);
        adde(u, v), adde(v, u);
    }
    for (ll i = 1; i <= n; ++i) sc(a[i]);
    dfs(1, 0);

    sc(m);
    for (ll i = 1, x; i <= m; ++i) {
        sc(x);
        q[i] = {dfn[x], mx[x], i};
    }
    std::sort(q + 1, q + 1 + m);

    // The window [l, r] starts empty as [1, 0]; r must be initialised to 0
    // (leaving it uninitialised is an easy mistake that is painful to debug).
    // Iterate over all m queries: bounding this loop by n left queries
    // n+1..m unanswered whenever m > n.
    for (ll i = 1, l = 1, r = 0; i <= m; ++i) {
        // Grow the window before shrinking it so no count ever goes negative.
        while (l > q[i].l) add(--l);
        while (r < q[i].r) add(++r);
        while (l < q[i].l) del(l++);
        while (r > q[i].r) del(r--);
        ans[q[i].i] = sum;
    }
    for (ll i = 1; i <= m; ++i) printf("%lld\n", ans[i]);
    return 0;
}

树上启发式合并

复杂度 $O(n\log n + m)$（预处理 $O(n\log n)$，每个询问 $O(1)$ 查表）

// DSU on tree / small-to-large merging (树上启发式合并): precompute, for every
// vertex, the number of distinct colours in its subtree, then answer each
// query by table lookup.
//
// measure() records preorder entry/exit timestamps and each vertex's heavy
// (largest) child. solve() handles the light children first and discards
// their counts. It then handles the heavy child and keeps its counts, and
// finally re-inserts every light subtree plus the vertex itself. A vertex is
// re-inserted once per light edge above it, i.e. O(log n) times, for
// O(n log n) in total.
//
// Input (stdin): n; n-1 edges "u v"; n colours c[1..n]; m; m query vertices.
// Output: one answer per line.
#include <cstdio>

typedef long long ll;

namespace {

// Capacity for vertex ids and colour values (freq[] is indexed by colour).
const int kMaxN = 100010;

// Forward-star adjacency: the arcs leaving u are head[u], arc_next[head[u]],
// ... until 0. Each undirected edge is stored as two arcs.
ll head[kMaxN], arc_to[2 * kMaxN], arc_next[2 * kMaxN], arc_count = 0;

ll colour[kMaxN];
ll tin[kMaxN], tout[kMaxN];  // subtree of u == timestamps [tin[u], tout[u]]
ll vertex_at[kMaxN];         // vertex visited at each timestamp
ll sub_size[kMaxN];          // subtree sizes; sub_size[0] == 0 is a sentinel
ll heavy[kMaxN];             // heavy child of u, or 0 for a leaf
ll freq[kMaxN];              // occurrences of each colour in the current set
ll answer[kMaxN];            // distinct colours in each vertex's subtree
ll stamp = 0;                // last timestamp handed out
ll distinct = 0;             // number of colours with freq > 0

void add_arc(ll from, ll dest) {
    ++arc_count;
    arc_to[arc_count] = dest;
    arc_next[arc_count] = head[from];
    head[from] = arc_count;
}

// First pass: timestamps, subtree sizes and heavy-child selection.
void measure(ll u, ll parent) {
    tin[u] = ++stamp;
    vertex_at[stamp] = u;
    sub_size[u] = 1;
    for (ll arc = head[u]; arc != 0; arc = arc_next[arc]) {
        const ll child = arc_to[arc];
        if (child == parent) continue;
        measure(child, u);
        sub_size[u] += sub_size[child];
        if (sub_size[child] > sub_size[heavy[u]]) heavy[u] = child;
    }
    tout[u] = stamp;
}

void insert_vertex(ll u) {
    if (freq[colour[u]]++ == 0) ++distinct;
}

void erase_vertex(ll u) {
    if (--freq[colour[u]] == 0) --distinct;
}

// Insert / erase every vertex whose timestamp lies in [from, to].
void insert_range(ll from, ll to) {
    for (ll t = from; t <= to; ++t) insert_vertex(vertex_at[t]);
}

void erase_range(ll from, ll to) {
    for (ll t = from; t <= to; ++t) erase_vertex(vertex_at[t]);
}

// True when v, a neighbour of u, is a child of u other than its heavy child.
bool is_light(ll u, ll parent, ll v) { return v != parent && v != heavy[u]; }

// Second pass. On return, the colour set holds u's subtree if keep is true,
// and is empty otherwise.
void solve(ll u, ll parent, bool keep) {
    for (ll arc = head[u]; arc != 0; arc = arc_next[arc])
        if (is_light(u, parent, arc_to[arc])) solve(arc_to[arc], u, false);
    if (heavy[u] != 0) solve(heavy[u], u, true);
    for (ll arc = head[u]; arc != 0; arc = arc_next[arc])
        if (is_light(u, parent, arc_to[arc]))
            insert_range(tin[arc_to[arc]], tout[arc_to[arc]]);
    insert_vertex(u);
    answer[u] = distinct;
    if (!keep) erase_range(tin[u], tout[u]);
}

}  // namespace

int main() {
    ll n = 0;
    scanf("%lld", &n);
    ll u, v;
    for (ll k = 1; k < n; ++k) {
        scanf("%lld", &u);
        scanf("%lld", &v);
        add_arc(u, v);
        add_arc(v, u);
    }
    for (ll k = 1; k <= n; ++k) scanf("%lld", &colour[k]);

    measure(1, 0);
    solve(1, 0, false);

    ll m = 0;
    ll x = 0;
    for (scanf("%lld", &m); m; --m) {
        scanf("%lld", &x);
        printf("%lld\n", answer[x]);
    }
    return 0;
}

附赠：数据生成程序（Python）

因为洛谷这题无法下载测试数据，这里的数据是我随手构造的，与洛谷原题的数据不同，仅供调试参考。

# Random test-data generator for the "distinct colours in a subtree" problem.
# Running it as a script writes ff1.in, ff2.in, ... to the current directory.
# Each file contains: n; n-1 edges "u v"; n colours on one line; m (= n);
# then the query vertices 1..n, i.e. every vertex is queried once.
from random import *  # only randint is used

idx, mn = 1, int(1e5)  # next output-file number; maximum n / colour value


def wri(r):
    """Write the test case r to ff<idx>.in and advance idx."""
    global idx
    with open('ff%d.in' % idx, 'w') as f:
        f.write(r)
    idx += 1


def getw(n, amx=mn, ami=1):
    """Return the colour line: '\\n' then n random ints in [ami, amx], space-separated."""
    return '\n' + ' '.join('%d' % randint(ami, amx) for _ in range(n))


def getquery(n):
    """Return the query section: '\\n<n>' followed by one line per vertex 1..n."""
    return '\n%d' % n + ''.join('\n%d' % i for i in range(1, n + 1))


def _case(n, edges, amx, ami):
    """Assemble a full test case from n, the edge lines and the colour range."""
    return '%d' % n + edges + getw(n, amx, ami) + getquery(n)


def gen_link(n, amx=1, ami=1):
    """Chain 1-2-...-n (deepest possible tree), colours in [ami, amx]."""
    edges = ''.join('\n%d %d' % (i, i + 1) for i in range(1, n))
    return _case(n, edges, amx, ami)


def gen_flower(n, amx=1, ami=1, rot=1):
    """Star centred on vertex rot (degree n-1), colours in [ami, amx]."""
    edges = ''.join('\n%d %d' % (rot, i) for i in range(1, n + 1) if i != rot)
    return _case(n, edges, amx, ami)


def gen_tree(n, amx=1, ami=1):
    """Random tree: each vertex i < n is joined to a uniformly random vertex in (i, n]."""
    # Edges are drawn before the colours, matching the original random-call order.
    edges = ''.join('\n%d %d' % (i, randint(i + 1, n)) for i in range(1, n))
    return _case(n, edges, amx, ami)


if __name__ == '__main__':
    # For each colour bound mt, write six cases:
    #   - a 2-vertex tree with colours in [mt, mn];
    #   - with colours in [1, mt]: a 10-vertex random tree, then four
    #     mn-vertex cases (a chain, stars centred on vertices 1 and 580,
    #     and a random tree).
    for mt in (1, 10, 100, mn):
        wri(gen_tree(2, mn, mt))
        wri(gen_tree(10, mt))
        wri(gen_link(mn, mt))
        wri(gen_flower(mn, mt))
        wri(gen_flower(mn, mt, rot=580))
        wri(gen_tree(mn, mt))