解法一：DFS 序 + 莫队，复杂度 $O(n\sqrt{n})$
#include <bits/stdc++.h>
using namespace std;
// Reads one long long from stdin into x.
#define sc(x) scanf("%lld", &x)
typedef long long ll;
// Capacity of every per-node / per-query / per-colour array.
// NOTE(review): colours index bin[] directly, so they must stay below this.
#define mn 100010

// Input and Mo's-algorithm state.
ll n;       // number of nodes
ll m;       // number of queries
ll sq;      // block size used to order the queries
ll ans[mn]; // ans[k] = answer of the k-th query in input order
ll bin[mn]; // bin[c] = occurrences of colour c inside the current window
ll sum;     // number of distinct colours inside the current window
ll a[mn];   // a[u] = colour of node u
ll hd[mn];  // adjacency-list head of each node (0 = no edge)
ll cnt;     // number of directed edges stored so far

// DFS-order data.
ll dfn[mn]; // dfn[u] = DFS index of node u
ll st;      // DFS timestamp counter
ll mx[mn];  // mx[u] = largest DFS index inside u's subtree
ll id[mn];  // id[t] = colour of the node whose DFS index is t

// A subtree query translated to the DFS-index interval [l, r]; i is the
// query's position in the input. The ordering is the one Mo's algorithm
// needs: by block of l first; inside one block by r, ascending for odd
// blocks and descending for even blocks (odd-even trick).
struct query
{
    ll l, r, i;
    bool operator<(const query &other) const
    {
        const ll mine = l / sq;
        const ll theirs = other.l / sq;
        if (mine != theirs)
            return l < other.l;
        if (mine & 1)
            return r < other.r;
        return r > other.r;
    }
} q[mn];

// Adjacency-list edge: target node and index of the next edge (0 ends a list).
struct edge
{
    ll to;
    ll nx;
} e[mn * 2];
// Counts the distinct colours in subtrees with Mo's algorithm on the DFS
// order: the subtree of x occupies the contiguous DFS-index range
// [dfn[x], mx[x]], so every query becomes a range query over id[].
// Input : n, n-1 undirected edges, colours a[1..n], then m and m query nodes.
// Output: one answer per line, in input order.
signed main()
{
    sc(n), sq = sqrt(n);
    // Stores the directed edge u -> v in the adjacency list.
    auto adde = [&](ll u, ll v)
    { e[++cnt] = {v, hd[u]}, hd[u] = cnt; };
    for (ll i = 1, u, v; i < n; ++i)
    {
        sc(u), sc(v), adde(u, v), adde(v, u);
    }
    for (ll i = 1; i <= n; ++i)
    {
        sc(a[i]);
    }
    // Numbers nodes in DFS order from 1 (tree rooted at node 1), records each
    // index's colour in id[], and returns the largest index in u's subtree.
    auto dfs = [&](auto self, ll u, ll fa) -> ll
    {
        dfn[u] = ++st, mx[u] = st, id[st] = a[u];
        for (ll i = hd[u], v; i; i = e[i].nx)
        {
            if ((v = e[i].to) != fa)
            {
                mx[u] = max(mx[u], self(self, v, u));
            }
        }
        return mx[u];
    };
    dfs(dfs, 1, 0);
    sc(m);
    for (ll i = 1, x; i <= m; ++i)
    {
        sc(x), q[i].l = dfn[x], q[i].r = mx[x], q[i].i = i;
    }
    sort(q + 1, q + 1 + m);
    // Window maintenance: sum counts the colours whose bin is non-zero.
    auto add = [&](ll i)
    { sum += bin[id[i]]++ == 0; };
    auto del = [&](ll i)
    { sum -= --bin[id[i]] == 0; };
    // The window [l, r] starts empty, which requires r = l - 1 = 0. Leaving r
    // uninitialised once cost the author half an hour of debugging.
    // The sweep runs over the m queries, not the n nodes: bounding it by n
    // left every query past index n unanswered when m > n, and processed
    // zero-filled slots of q[] when m < n.
    for (ll i = 1, l = 1, r = 0; i <= m; ++i)
    {
        // Extend before shrinking so the window never becomes l > r + 1.
        while (l > q[i].l)
            add(--l);
        while (r < q[i].r)
            add(++r);
        while (l < q[i].l)
            del(l++);
        while (r > q[i].r)
            del(r--);
        ans[q[i].i] = sum;
    }
    for (ll i = 1; i <= m; ++i)
    {
        printf("%lld\n", ans[i]);
    }
    return 0;
}
解法二：树上启发式合并（DSU on tree），复杂度 $O(n\log n)$
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
// Reads one long long from stdin into x.
#define sc(x) scanf("%lld", &x)
// Capacity of every per-node / per-colour array.
// NOTE(review): colours index s[] directly, so they must stay below this.
#define mn 100010

ll n;       // number of nodes
ll m;       // number of queries still to be read
ll cnt;     // DFS timestamp counter
ll dfn[mn]; // dfn[t] = node whose DFS index is t
ll big[mn]; // big[u] = heavy child of u (0 when u has no child)
ll siz[mn]; // siz[u] = number of nodes in u's subtree
ll tot;     // number of distinct colours currently counted
ll c[mn];   // c[u] = colour of node u
ll ecnt;    // number of directed edges stored so far
ll hd[mn];  // adjacency-list head of each node (0 = no edge)
ll lf[mn];  // lf[u] = DFS index of u
ll rf[mn];  // rf[u] = last DFS index inside u's subtree
ll s[mn];   // s[k] = occurrences of colour k among the counted nodes
ll ans[mn]; // ans[u] = distinct colours in u's subtree
ll x;       // query node currently being read

// Adjacency-list edge: target node and index of the next edge (0 ends a list).
struct edge
{
    ll to;
    ll nx;
} e[mn * 2];
// Pushes the directed edge u -> v onto the front of u's adjacency list.
void adde(ll u, ll v)
{
    ++ecnt;
    e[ecnt].to = v;
    e[ecnt].nx = hd[u];
    hd[u] = ecnt;
}
// First pass of DSU on tree.
// - Records u's DFS interval [lf[u], rf[u]] and fills dfn[] with the
//   inverse map (DFS index -> node).
// - Computes subtree sizes.
// - Picks big[u]: the first child in adjacency-list order whose subtree size
//   is strictly the largest seen.
void dfs1(ll u, ll fa)
{
    siz[u] = 1;
    lf[u] = ++cnt;
    dfn[cnt] = u;
    for (ll k = hd[u]; k != 0; k = e[k].nx)
    {
        const ll child = e[k].to;
        if (child == fa)
            continue;
        dfs1(child, u);
        siz[u] += siz[child];
        if (siz[child] > siz[big[u]])
            big[u] = child;
    }
    rf[u] = cnt;
}
// Counts node u's colour; tot grows when that colour was not present yet.
void add(ll u)
{
    if (s[c[u]]++ == 0)
        ++tot;
}
// Uncounts node u's colour; tot shrinks when its last occurrence disappears.
void remove(ll u)
{
    if (--s[c[u]] == 0)
        --tot;
}
// Second pass of DSU on tree: fills ans[u] with the number of distinct
// colours in u's subtree. When save is true, the colour counts of u's
// subtree are left in s[]/tot on return; otherwise they are removed again.
void dfs2(ll u, ll fa, bool save)
{
    const ll heavy = big[u];
    // Solve every light child first, discarding its counts afterwards.
    for (ll k = hd[u]; k != 0; k = e[k].nx)
    {
        const ll child = e[k].to;
        if (child == fa || child == heavy)
            continue;
        dfs2(child, u, false);
    }
    // Solve the heavy child last and keep its counts.
    if (heavy != 0)
        dfs2(heavy, u, true);
    // Re-count every node of the light subtrees on top of the heavy one.
    for (ll k = hd[u]; k != 0; k = e[k].nx)
    {
        const ll child = e[k].to;
        if (child == fa || child == heavy)
            continue;
        for (ll pos = lf[child]; pos <= rf[child]; ++pos)
            add(dfn[pos]);
    }
    add(u);
    ans[u] = tot;
    if (save)
        return;
    // Caller does not keep this subtree: undo every count it contributed.
    for (ll pos = lf[u]; pos <= rf[u]; ++pos)
        remove(dfn[pos]);
}
signed main()
{
sc(n);
for (ll i = 1, u, v; i < n; ++i)
{
sc(u), sc(v), adde(u, v), adde(v, u);
}
for (ll i = 1; i <= n; ++i)
{
sc(c[i]);
}
dfs1(1, 0), dfs2(1, 0, false);
for (sc(m); m; --m)
{
sc(x);
printf("%lld\n", ans[x]);
}
return 0;
}
因为洛谷这题无法下载数据，下面的数据是我用脚造的（随手生成），跟洛谷原题不一样，仅供调试参考。
from random import *
# Sequence number of the next ff<idx>.in file written by wri().
idx = 1
# Largest test size and default upper bound of the random colours.
mn = 100000
def wri(r):
    """Write the test-case text r to ff<idx>.in, then advance idx."""
    global idx
    path = 'ff%d.in' % idx
    with open(path, 'w') as out:
        out.write(r)
    idx += 1
def getw(n, amx=mn, ami=1):
    """Return a newline followed by n space-separated random colours.

    Each colour is drawn with randint(ami, amx). When n <= 0 the result is
    the empty string (no leading newline).
    """
    body = ' '.join('%d' % randint(ami, amx) for _ in range(n))
    return '\n' + body if body else ''
def getquery(n):
    """Return the query section: a line with n, then one query per node 1..n."""
    header = '\n%d' % n
    return header + ''.join('\n%d' % node for node in range(1, n + 1))
def gen_link(n, amx=1, ami=1):
    """Build a test whose tree is the chain 1-2-...-n, colours in [ami, amx]."""
    edges = ''.join('\n%d %d' % (node, node + 1) for node in range(1, n))
    return '%d' % n + edges + getw(n, amx, ami) + getquery(n)
def gen_flower(n, amx=1, ami=1, rot=1):
    """Build a star test: every node 1..n other than rot is joined to rot."""
    spokes = ''.join(
        '\n%d %d' % (rot, leaf) for leaf in range(1, n + 1) if leaf != rot
    )
    return '%d' % n + spokes + getw(n, amx, ami) + getquery(n)
def gen_tree(n, amx=1, ami=1):
    """Build a random-tree test.

    Every node i < n is linked to a random node in (i, n], so there are
    exactly n - 1 edges and all nodes reach node n.
    """
    edges = ''.join('\n%d %d' % (i, randint(i + 1, n)) for i in range(1, n))
    return '%d' % n + edges + getw(n, amx, ami) + getquery(n)
# One round per mt value; mt sets the colour range of that round's cases.
# Each round writes six files, in order:
#   1. a 2-node tree with colours in [mt, mn]
#   2. a 10-node random tree
#   3. a chain of mn nodes
#   4. a star centred at node 1
#   5. a star centred at node 580
#   6. a random tree of mn nodes
for mt in (1, 10, 100, mn):
    builders = (
        lambda: gen_tree(2, mn, mt),
        lambda: gen_tree(10, mt),
        lambda: gen_link(mn, mt),
        lambda: gen_flower(mn, mt),
        lambda: gen_flower(mn, mt, rot=580),
        lambda: gen_tree(mn, mt),
    )
    for build in builders:
        wri(build())