2026-05-04:树组的交互代价总和。用go语言,给定一个整数 n,以及一棵有 n 个节点的无向树,节点编号为 0 到 n-1。树的结构由数组 edges

🤖 AI总结

主题

使用虚树和贡献法高效计算树中同组节点对的路径长度总和。

摘要

文章讲解了利用虚树和贡献法高效计算树中同组节点对路径长度总和的问题,适用于n=1e5且组数≤20的场景。

关键信息

  • 1 给定一棵树和每个节点的组号,计算所有同组节点对之间路径的边数总和。
  • 2 利用组数不多于20的条件,对每个组构建虚树,并使用贡献法(每条边被经过的次数)累加代价。
  • 3 时间复杂度O(n log n),空间复杂度O(n log n),可处理n=1e5。

2026-05-04:树组的交互代价总和。用go语言,给定一个整数 n,以及一棵有 n 个节点的无向树,节点编号为 0 到 n-1。树的结构由数组 edges 表示:数组长度为 n-1,其中 edges[i] = [u, v] 表示节点 u 与节点 v 之间有一条无向边。

另给定一个数组 group,长度为 n。group[i] 表示节点 i 所属的组。

如果两个节点 u 和 v 满足 group[u] == group[v],则称它们属于同一组。由于是树结构,任意两个节点之间都存在且仅存在一条唯一路径。所谓交互代价定义为:这条唯一路径上包含的边的数量。

目标:枚举所有无序且不同的节点对 (u, v),要求它们同组(group[u] == group[v])。把这些节点对的交互代价全部累加并返回总和;若不存在满足条件的节点对,则返回 0。

1 <= n <= 100000。

edges.length == n – 1。

edges[i] = [ui, vi]。

0 <= ui, vi <= n – 1。

group.length == n。

1 <= group[i] <= 20。

输入保证 edges 表示一棵有效的树。

输入: n = 3, edges = [[0,1],[1,2]], group = [3,2,3]。

输出: 2。

解释:

节点 0 和节点 2 属于组 3,它们之间的交互代价为 2。

节点 1 属于不同的组,因此没有有效的节点对。

总交互代价为 2。

题目来自力扣3786。

代码执行全流程详细拆解第一步:构建原始树的邻接表

1. 初始化一个长度为n的二维数组,作为无向树的邻接表

  • 2. 遍历所有边,把每条边的两个节点互相添加到对方的邻接列表中;

  • 3. 作用:让程序可以快速访问每个节点的所有相邻节点,为后续树的遍历做准备。

    第二步:预处理树的基础信息(DFS + 倍增LCA)

    这一步是为了快速求任意两点的最近公共祖先(LCA)两点间路径长度,是树问题的核心预处理。

    子步骤1:DFS遍历树,记录核心信息

    1. 定义递归函数,从根节点0开始遍历整棵树;

  • 2. 记录每个节点的DFS时间戳(dfn):用于后续给节点排序,是构建虚树的关键;

  • 3. 记录每个节点的父节点(第一层祖先)

  • 4. 记录每个节点的深度(dep):根节点深度为0,子节点深度=父节点+1;

  • 5. 深度的作用:两点间路径边数 =深度[u] + 深度[v] - 2×深度[LCA(u,v)]

    子步骤2:倍增法预处理祖先数组

    1. 计算树的最大深度对应的二进制位数,确定倍增的层数;

  • 2. 预处理每个节点的2^i级祖先(2级、4级、8级…祖先);

  • 3. 作用:实现O(logn)时间查询任意两个节点的最近公共祖先(LCA)。

    子步骤3:封装LCA查询函数

    1. 封装两个工具函数:

    • 把节点向上提升到指定深度;

  • • 求任意两个节点的最近公共祖先;

    2. 这是后续计算路径长度、构建虚树的基础工具。

    第三步:按分组归类所有节点

    1. 创建一个哈希表(字典),key=组号,value=该组所有节点的列表

  • 2. 遍历所有节点,把每个节点按照group数组的值,放入对应组的列表中;

  • 3. 作用:后续只需要逐组计算,组号最多20个,极大减少计算量。

    第四步:对每一组单独构建「虚树」(核心优化)

    因为组号只有20个,我们逐个组处理,每组独立计算贡献:

    虚树作用:只保留当前组的节点 + 这些节点之间路径的必要公共祖先,剔除无关节点,把大树压缩成小树,大幅降低计算量。

    单组构建虚树的完整步骤

    1.按DFS时间戳排序:把当前组的所有节点,按照第一步记录的dfn从小到大排序;

  • 2.初始化栈和虚树:用根节点作为栈的初始元素,清空虚树结构;

  • 3.标记关键节点:把当前组的节点标记为「真实关键节点」;

  • 4.栈+LCA构建虚树

    • 遍历排序后的每个节点;

  • • 计算栈顶节点与当前节点的LCA(路径拐点);

  • • 不断回溯栈,给虚树添加边,直到栈顶深度小于LCA深度;

  • • 如果LCA不在栈中,将其加入栈和虚树;

  • • 最后把当前节点入栈;

    5.收尾加边:遍历结束后,把栈中剩余节点依次连边,完成虚树构建。

    第五步:在虚树上DFS计算本组的交互代价(贡献法)

    这是计算答案的核心步骤,使用贡献法:不枚举所有节点对(会超时),而是计算每条边被多少对节点经过,总代价 = 边数 × 经过的节点对数。

    计算步骤

    1. 定义递归DFS函数,遍历当前组的虚树;

  • 2. 递归统计每个子树中当前组的节点数量

  • 3. 对于虚树上的每一条边

    • 边的实际长度 = 子节点深度 – 父节点深度(对应原始树的边数);

  • • 设子树内有sz个本组节点,本组总节点数为total

  • • 这条边会被sz × (total - sz)对节点经过;

  • • 本组总代价 += 边长度 ×sz × (total - sz)

    4. 把本组的代价累加到全局答案中;

    5. 一组计算完成后,重置虚树,开始处理下一组。

    第六步:所有组计算完成,返回最终答案

    1. 遍历完所有组(最多20组);

  • 2. 全局累加的结果就是所有同组节点对的交互代价总和

  • 3. 示例中仅组3贡献了2,最终输出2。

    时间复杂度 & 额外空间复杂度分析 一、总时间复杂度

    O(n × logn + G × k × logk)
    拆解说明:

    1.预处理LCA:DFS遍历树是O(n),倍增预处理是O(n × logn)

  • 2.分组归类O(n)

  • 3.构建虚树+计算贡献

    • 组数量G ≤ 20(题目限定);

  • • 每组节点数k,排序O(k logk),构建虚树+DFSO(k)

  • • 所有组总耗时O(n logn)

    4. 整体主导项:O(n × logn)完全满足n=1e5的时间要求

    二、总额外空间复杂度

    O(n × logn)
    拆解说明:

    1. 邻接表:O(n)

  • 2. 倍增祖先数组:O(n × 17)(log₂(1e5)≈17),是核心空间开销;

  • 3. DFN、深度数组、虚树、栈、哈希表:均为O(n)

  • 4. 整体空间复杂度由倍增数组主导:O(n × logn)

    总结

    1. 整体流程:建原始树 → 预处理LCA → 按组分节点 → 每组建虚树压缩 → 贡献法算代价 → 累加答案

  • 2. 核心优化:利用group[i]≤20的限定,逐组处理+虚树压缩,避免暴力枚举节点对;

  • 3. 时间复杂度:O(n logn),高效处理1e5节点;

  • 4. 空间复杂度:O(n logn),符合算法题常规空间要求。

    Go完整代码如下:

    package main

    import (
    "fmt"
    "math/bits"
    "slices"
    )

    func interactionCosts(n int, edges [][]int, group []int) (ans int64) {
    g := make([][]int, n)
    for _, e := range edges {
    v, w := e[0], e[1]
    g[v] = append(g[v], w)
    g[w] = append(g[w], v)
    }

    dfn := make([]int, n)
    ts := 0
    pa := make([][17]int, n)
    dep := make([]int, n)
    var build func(int, int)
    build = func(v, p int) {
    dfn[v] = ts
    ts++
    pa[v][0] = p
    for _, w := range g[v] {
    if w != p {
    dep[w] = dep[v] + 1
    build(w, v)
    }
    }
    }
    build(0, -1)
    mx := bits.Len(uint(n))
    for i := range mx - 1 {
    for v := range pa {
    p := pa[v][i]
    if p != -1 {
    pa[v][i+1] = pa[p][i]
    } else {
    pa[v][i+1] = -1
    }
    }
    }
    uptoDep := func(v, d int)int {
    for k := uint32(dep[v] - d); k > 0; k &= k - 1 {
    v = pa[v][bits.TrailingZeros32(k)]
    }
    return v
    }
    getLCA := func(v, w int)int {
    if dep[v] > dep[w] {
    v, w = w, v
    }
    w = uptoDep(w, dep[v])
    if w == v {
    return v
    }
    for i := mx - 1; i >= 0; i-- {
    pv, pw := pa[v][i], pa[w][i]
    if pv != pw {
    v, w = pv, pw
    }
    }
    return pa[v][0]
    }

    nodesMap := map[int][]int{}
    for i, x := range group {
    nodesMap[x] = append(nodesMap[x], i)
    }

    vt := make([][]int, n) // 虚树
    isNode := make([]int, n) // 用来区分是关键节点还是 LCA
    for i := range isNode {
    isNode[i] = -1
    }
    addVtEdge := func(v, w int) {
    vt[v] = append(vt[v], w) // 往虚树上添加一条有向边
    }
    const root = 0
    st := []int{root} // 用根节点作为栈底哨兵

    for val, nodes := range nodesMap {
    // 对于相同点权的这一组关键节点 nodes,构建虚树
    slices.SortFunc(nodes, func(a, b int)int { return dfn[a] - dfn[b] })
    vt[root] = vt[root][:0] // 重置虚树
    st = st[:1]
    for _, v := range nodes {
    isNode[v] = val
    if v == root {
    continue
    }
    vt[v] = vt[v][:0]
    lca := getLCA(st[len(st)-1], v) // 路径的拐点(LCA)也加到虚树中
    // 回溯,加边
    forlen(st) > 1 && dfn[lca] <= dfn[st[len(st)-2]] {
    addVtEdge(st[len(st)-2], st[len(st)-1])
    st = st[:len(st)-1]
    }
    if lca != st[len(st)-1] { // lca 不在栈中(首次遇到)
    vt[lca] = vt[lca][:0]
    addVtEdge(lca, st[len(st)-1])
    st[len(st)-1] = lca // 加到栈中
    }
    st = append(st, v)
    }
    // 最后的回溯,加边
    for i := 1; i < len(st); i++ {
    addVtEdge(st[i-1], st[i])
    }

    var dfs func(int)int
    dfs = func(v int) (size int) {
    // 如果 isNode[v] != t,那么 v 只是关键节点之间路径上的「拐点」
    if isNode[v] == val {
    size = 1
    }
    for _, w := range vt[v] {
    sz := dfs(w)
    wt := dep[w] - dep[v] // 虚树边权
    // 贡献法
    ans += int64(wt) * int64(sz) * int64(len(nodes)-sz)
    size += sz
    }
    return
    }

    rt := root
    if isNode[rt] != val && len(vt[rt]) == 1 {
    // 注意 root 只是一个哨兵,不一定在虚树上,得从真正的根节点开始
    rt = vt[rt][0]
    }
    dfs(rt)
    }

    return
    }

    func main() {
    n := 3
    edges := [][]int{{0, 1}, {1, 2}}
    group := []int{3, 2, 3}
    result := interactionCosts(n, edges, group)
    fmt.Println(result)
    }

    2026-05-04:树组的交互代价总和。用go语言,给定一个整数 n,以及一棵有 n 个节点的无向树,节点编号为 0 到 n-1。树的结构由数组 edges

    Python完整代码如下:

    # -*-coding:utf-8-*-

    import sys
    sys.setrecursionlimit(10**6)

    def interactionCosts(n, edges, group):
    ans = 0

    g = [[] for _ in range(n)]
    for v, w in edges:
    g[v].append(w)
    g[w].append(v)

    dfn = [0] * n
    ts = 0
    pa = [[-1] * 17for _ in range(n)]
    dep = [0] * n

    def build(v, p):
    nonlocal ts
    dfn[v] = ts
    ts += 1
    pa[v][0] = p
    for w in g[v]:
    if w != p:
    dep[w] = dep[v] + 1
    build(w, v)

    build(0, -1)

    mx = n.bit_length()
    for i in range(mx - 1):
    for v in range(n):
    p = pa[v][i]
    if p != -1:
    pa[v][i+1] = pa[p][i]
    else:
    pa[v][i+1] = -1

    def uptoDep(v, d):
    k = dep[v] - d
    while k:
    step = (k & -k).bit_length() - 1
    v = pa[v][step]
    k -= (1 << step)
    return v

    def getLCA(v, w):
    if dep[v] > dep[w]:
    v, w = w, v
    w = uptoDep(w, dep[v])
    if w == v:
    return v
    for i in range(mx-1, -1, -1):
    pv, pw = pa[v][i], pa[w][i]
    if pv != pw:
    v, w = pv, pw
    return pa[v][0]

    nodesMap = {}
    for i, x in enumerate(group):
    nodesMap.setdefault(x, []).append(i)

    vt = [[] for _ in range(n)]
    isNode = [-1] * n

    def addVtEdge(v, w):
    vt[v].append(w)

    root = 0
    st = [root]

    for val, nodes in nodesMap.items():
    nodes.sort(key=lambda x: dfn[x])
    vt[root] = []
    st = [root]
    for v in nodes:
    isNode[v] = val
    if v == root:
    continue
    vt[v] = []
    lca = getLCA(st[-1], v)
    while len(st) > 1 and dfn[lca] <= dfn[st[-2]]:
    addVtEdge(st[-2], st[-1])
    st.pop()
    if lca != st[-1]:
    vt[lca] = []
    addVtEdge(lca, st[-1])
    st[-1] = lca
    st.append(v)
    for i in range(1, len(st)):
    addVtEdge(st[i-1], st[i])

    sys.setrecursionlimit(10**6)
    def dfs(v):
    nonlocal ans
    size = 1if isNode[v] == val else0
    for w in vt[v]:
    sz = dfs(w)
    wt = dep[w] - dep[v]
    ans += wt * sz * (len(nodes) - sz)
    size += sz
    return size

    rt = root
    if isNode[rt] != val and len(vt[rt]) == 1:
    rt = vt[rt][0]
    dfs(rt)

    return ans

    if __name__ == "__main__":
    n = 3
    edges = [[0, 1], [1, 2]]
    group = [3, 2, 3]
    result = interactionCosts(n, edges, group)
    print(result)

    2026-05-04:树组的交互代价总和。用go语言,给定一个整数 n,以及一棵有 n 个节点的无向树,节点编号为 0 到 n-1。树的结构由数组 edges

    C++完整代码如下:

      
    





    using namespace std;

    class Solution {
    public:
    long long interactionCosts(int n, vector int >>& edges, vector< int >& group) {
    long long ans = 0 ;

    // 构建邻接表
    vector int >> g(n);
    for (auto& e : edges) {
    int v = e[ 0 ], w = e[ 1 ];
    g[v].push_back(w);
    g[w].push_back(v);
    }

    // 预处理 DFS 序、深度和倍增祖先
    vector< int > dfn(n, 0 );
    int ts = 0 ;
    vector int , 17 >> pa(n);
    vector< int > dep(n, 0 );

    function int , int )> build = [&]( int v, int p) {
    dfn[v] = ts++;
    pa[v][ 0 ] = p;
    for ( int w : g[v]) {
    if (w != p) {
    dep[w] = dep[v] + 1 ;
    build(w, v);
    }
    }
    };
    build( 0 , -1 );

    int mx = 32 - __builtin_clz(n); // bits.Len(uint(n))
    for ( int i = 0 ; i < mx - 1 ; i++) {
    for ( int v = 0 ; v < n; v++) {
    int p = pa[v][i];
    if (p != -1 ) {
    pa[v][i + 1 ] = pa[p][i];
    } else {
    pa[v][i + 1 ] = -1 ;
    }
    }
    }

    // 跳到指定深度
    auto uptoDep = [&]( int v, int d) -> int {
    int k = dep[v] - d;
    while (k > 0 ) {
    int step = __builtin_ctz(k);
    v = pa[v][step];
    k &= k - 1 ;
    }
    return v;
    };

    // 获取 LCA
    auto getLCA = [&]( int v, int w) -> int {
    if (dep[v] > dep[w]) {
    swap(v, w);
    }
    w = uptoDep(w, dep[v]);
    if (w == v) return v;
    for ( int i = mx - 1 ; i >= 0 ; i--) {
    int pv = pa[v][i], pw = pa[w][i];
    if (pv != pw) {
    v = pv;
    w = pw;
    }
    }
    return pa[v][ 0 ];
    };

    // 按点权分组节点
    map < int , vector< int >> nodesMap;
    for ( int i = 0 ; i < n; i++) {
    nodesMap[group[i]].push_back(i);
    }

    // 虚树
    vector int >> vt(n);
    vector< int > isNode(n, -1 );

    auto addVtEdge = [&]( int v, int w) {
    vt[v].push_back(w);
    };

    const int root = 0 ;
    vector< int > st;

    // 处理每个点权组
    for (auto& [val, nodes] : nodesMap) {
    // 按 DFS 序排序
    sort(nodes.begin(), nodes.end(), [&]( int a, int b) {
    return dfn[a] < dfn[b];
    });

    // 清空虚树
    for ( int v : nodes) {
    vt[v].clear();
    }
    vt[root].clear();

    st.clear();
    st.push_back(root);

    // 构建虚树
    for ( int v : nodes) {
    isNode[v] = val;
    if (v == root) continue ;

    vt[v].clear();
    int lca = getLCA(st.back(), v);

    // 回溯并加边
    while (st.size() > 1 && dfn[lca] <= dfn[st[st.size() - 2 ]]) {
    addVtEdge(st[st.size() - 2 ], st.back());
    st.pop_back();
    }

    if (lca != st.back()) {
    vt[lca].clear();
    addVtEdge(lca, st.back());
    st.back() = lca;
    }

    st.push_back(v);
    }

    // 添加剩余边
    for ( int i = 1 ; i < st.size(); i++) {
    addVtEdge(st[i - 1 ], st[i]);
    }

    // DFS 遍历虚树计算贡献
    function< int ( int )> dfs = [&]( int v) -> int {
    int size = (isNode[v] == val) ? 1 : 0 ;
    for ( int w : vt[v]) {
    int sz = dfs(w);
    int wt = dep[w] - dep[v];
    ans += 1 LL * wt * sz * (nodes.size() - sz);
    size += sz;
    }
    return size;
    };

    // 找到真正的根节点
    int rt = root;
    if (isNode[rt] != val && vt[rt].size() == 1 ) {
    rt = vt[rt][ 0 ];
    }
    dfs(rt);
    }

    return ans;
    }
    };

    int main() {
    int n = 3 ;
    vector int >> edges = {{ 0 , 1 }, { 1 , 2 }};
    vector< int > group = { 3 , 2 , 3 };

    Solution solution;
    long long result = solution.interactionCosts(n, edges, group);
    cout << result << endl;

    return 0 ;
    }

    2026-05-04:树组的交互代价总和。用go语言,给定一个整数 n,以及一棵有 n 个节点的无向树,节点编号为 0 到 n-1。树的结构由数组 edges

    我们相信人工智能为普通人提供了一种“增强工具”,并致力于分享全方位的AI知识。在这里,您可以找到最新的AI科普文章、工具评测、提升效率的秘籍以及行业洞察。 欢迎关注“福大大架构师每日一题”,发消息可获得面试资料,让AI助力您的未来发展。

    © 版权声明

    相关文章