Skip to content
鼓励作者:欢迎打赏犒劳

go集合工具类

list与tree互转

go
package utils

// tree_util.go

import "fmt"

// TreeNode 通用树节点结构
type TreeNode[K comparable, E any] struct {
	ID       K                 `json:"id"`        // 主键
	ParentID K                 `json:"parent_id"` // 父节点ID
	Name     string            `json:"name"`      // 名称
	Sort     int               `json:"sort"`      // 排序权重
	Deep     int               `json:"deep"`      // 层级深度(从 0 开始)
	Extra    E                 `json:"extra"`     // 额外数据(原始对象)
	Children []*TreeNode[K, E] `json:"children"`  // 子节点列表
}

// ListToTree 将扁平列表转换为树结构
// 参数:
//
//	list: 扁平的节点切片
//	rootParentID: 根节点的 ParentID(例如 0 或 "")
//
// 返回:
//
//	根节点列表(森林),error
func ListToTree[K comparable, T any](
	list []TreeNode[K, T],
	rootParentID K,
) ([]*TreeNode[K, T], error) {
	nodeMap := make(map[K]*TreeNode[K, T])
	var roots []*TreeNode[K, T]

	// 第一步:将所有节点放入 map,方便查找
	for i := range list {
		node := &list[i]
		nodeMap[node.ID] = node
		node.Children = []*TreeNode[K, T]{} // 初始化
	}

	// 第二步:建立父子关系
	for i := range list {
		node := &list[i]

		if node.ParentID == rootParentID || node.ParentID == node.ID {
			// 是根节点,或指向自己
			roots = append(roots, node)
		} else {
			parentNode, exists := nodeMap[node.ParentID]
			if !exists {
				return nil, fmt.Errorf("parent node not found for node ID=%v, ParentID=%v", node.ID, node.ParentID)
			}
			parentNode.Children = append(parentNode.Children, node)
		}

		// 设置深度(可选)
		if node.ParentID == rootParentID {
			node.Deep = 0
		} else {
			parentNode, exists := nodeMap[node.ParentID]
			if exists {
				node.Deep = parentNode.Deep + 1
			} else {
				node.Deep = 0 // 默认
			}
		}
	}

	return roots, nil
}

// TreeToList 将树结构展开为扁平列表(前序遍历)
func TreeToList[K comparable, T any](roots []*TreeNode[K, T]) []TreeNode[K, T] {
	var result []TreeNode[K, T]

	var dfs func(node *TreeNode[K, T])
	dfs = func(node *TreeNode[K, T]) {
		if node == nil {
			return
		}
		// 复制节点(避免指针问题)
		copied := *node
		copied.Children = nil // 不包含子节点
		result = append(result, copied)

		for _, child := range node.Children {
			dfs(child)
		}
	}

	for _, root := range roots {
		dfs(root)
	}

	return result
}

测试

go
// main.go
package main

import (
	"demo/utils"
	"fmt"
)

type Org struct {
	Code string `json:"code"`
	Desc string `json:"desc"`
}

func main() {
	// 示例数据:组织架构
	nodes := []utils.TreeNode[int, Org]{
		{ID: 1, ParentID: 0, Name: "总公司", Sort: 1, Extra: Org{Code: "A001", Desc: "Headquarters"}},
		{ID: 2, ParentID: 1, Name: "研发部", Sort: 10, Extra: Org{Code: "R001", Desc: "R&D"}},
		{ID: 3, ParentID: 1, Name: "销售部", Sort: 20, Extra: Org{Code: "S001", Desc: "Sales"}},
		{ID: 4, ParentID: 2, Name: "前端组", Sort: 11, Extra: Org{Code: "FE01", Desc: "Frontend"}},
		{ID: 5, ParentID: 2, Name: "后端组", Sort: 12, Extra: Org{Code: "BE01", Desc: "Backend"}},
		{ID: 6, ParentID: 3, Name: "华东区", Sort: 21, Extra: Org{Code: "EA01", Desc: "East China"}},
	}

	// 转成树
	roots, err := utils.ListToTree(nodes, 0)
	if err != nil {
		panic(err)
	}

	fmt.Println("=== Tree (JSON) ===")
	PrintTree(roots, 0)

	// 再转回 list
	flat := utils.TreeToList(roots)
	fmt.Println("\n=== Flattened List ===")
	for _, n := range flat {
		fmt.Printf("[%d] %s (Parent: %d, Deep: %d)\n", n.ID, n.Name, n.ParentID, n.Deep)
	}
}

// 简单打印树结构
func PrintTree[K comparable, T any](nodes []*utils.TreeNode[K, T], level int) {
	indent := ""
	for i := 0; i < level; i++ {
		indent += "  "
	}
	for _, n := range nodes {
		fmt.Printf("%s- %s (ID: %v, Deep: %d)\n", indent, n.Name, n.ID, n.Deep)
		if len(n.Children) > 0 {
			PrintTree(n.Children, level+1)
		}
	}
}

结果

text
=== Tree (JSON) ===
- 总公司 (ID: 1, Deep: 0)
  - 研发部 (ID: 2, Deep: 1)
    - 前端组 (ID: 4, Deep: 2)
    - 后端组 (ID: 5, Deep: 2)
  - 销售部 (ID: 3, Deep: 1)
    - 华东区 (ID: 6, Deep: 2)

=== Flattened List ===
[1] 总公司 (Parent: 0, Deep: 0)
[2] 研发部 (Parent: 1, Deep: 1)
[4] 前端组 (Parent: 2, Deep: 2)
[5] 后端组 (Parent: 2, Deep: 2)
[3] 销售部 (Parent: 1, Deep: 1)
[6] 华东区 (Parent: 3, Deep: 2)

如有转载或 CV 的请标注本站原文地址