
go集合工具类
list与tree互转
go
package utils
// tree_util.go
import "fmt"
// TreeNode 通用树节点结构
type TreeNode[K comparable, E any] struct {
ID K `json:"id"` // 主键
ParentID K `json:"parent_id"` // 父节点ID
Name string `json:"name"` // 名称
Sort int `json:"sort"` // 排序权重
Deep int `json:"deep"` // 层级深度(从 0 开始)
Extra E `json:"extra"` // 额外数据(原始对象)
Children []*TreeNode[K, E] `json:"children"` // 子节点列表
}
// ListToTree 将扁平列表转换为树结构
// 参数:
//
// list: 扁平的节点切片
// rootParentID: 根节点的 ParentID(例如 0 或 "")
//
// 返回:
//
// 根节点列表(森林),error
func ListToTree[K comparable, T any](
list []TreeNode[K, T],
rootParentID K,
) ([]*TreeNode[K, T], error) {
nodeMap := make(map[K]*TreeNode[K, T])
var roots []*TreeNode[K, T]
// 第一步:将所有节点放入 map,方便查找
for i := range list {
node := &list[i]
nodeMap[node.ID] = node
node.Children = []*TreeNode[K, T]{} // 初始化
}
// 第二步:建立父子关系
for i := range list {
node := &list[i]
if node.ParentID == rootParentID || node.ParentID == node.ID {
// 是根节点,或指向自己
roots = append(roots, node)
} else {
parentNode, exists := nodeMap[node.ParentID]
if !exists {
return nil, fmt.Errorf("parent node not found for node ID=%v, ParentID=%v", node.ID, node.ParentID)
}
parentNode.Children = append(parentNode.Children, node)
}
// 设置深度(可选)
if node.ParentID == rootParentID {
node.Deep = 0
} else {
parentNode, exists := nodeMap[node.ParentID]
if exists {
node.Deep = parentNode.Deep + 1
} else {
node.Deep = 0 // 默认
}
}
}
return roots, nil
}
// TreeToList 将树结构展开为扁平列表(前序遍历)
func TreeToList[K comparable, T any](roots []*TreeNode[K, T]) []TreeNode[K, T] {
var result []TreeNode[K, T]
var dfs func(node *TreeNode[K, T])
dfs = func(node *TreeNode[K, T]) {
if node == nil {
return
}
// 复制节点(避免指针问题)
copied := *node
copied.Children = nil // 不包含子节点
result = append(result, copied)
for _, child := range node.Children {
dfs(child)
}
}
for _, root := range roots {
dfs(root)
}
return result
}
测试
go
// main.go
package main
import (
"demo/utils"
"fmt"
)
type Org struct {
Code string `json:"code"`
Desc string `json:"desc"`
}
func main() {
// 示例数据:组织架构
nodes := []utils.TreeNode[int, Org]{
{ID: 1, ParentID: 0, Name: "总公司", Sort: 1, Extra: Org{Code: "A001", Desc: "Headquarters"}},
{ID: 2, ParentID: 1, Name: "研发部", Sort: 10, Extra: Org{Code: "R001", Desc: "R&D"}},
{ID: 3, ParentID: 1, Name: "销售部", Sort: 20, Extra: Org{Code: "S001", Desc: "Sales"}},
{ID: 4, ParentID: 2, Name: "前端组", Sort: 11, Extra: Org{Code: "FE01", Desc: "Frontend"}},
{ID: 5, ParentID: 2, Name: "后端组", Sort: 12, Extra: Org{Code: "BE01", Desc: "Backend"}},
{ID: 6, ParentID: 3, Name: "华东区", Sort: 21, Extra: Org{Code: "EA01", Desc: "East China"}},
}
// 转成树
roots, err := utils.ListToTree(nodes, 0)
if err != nil {
panic(err)
}
fmt.Println("=== Tree (JSON) ===")
PrintTree(roots, 0)
// 再转回 list
flat := utils.TreeToList(roots)
fmt.Println("\n=== Flattened List ===")
for _, n := range flat {
fmt.Printf("[%d] %s (Parent: %d, Deep: %d)\n", n.ID, n.Name, n.ParentID, n.Deep)
}
}
// 简单打印树结构
func PrintTree[K comparable, T any](nodes []*utils.TreeNode[K, T], level int) {
indent := ""
for i := 0; i < level; i++ {
indent += " "
}
for _, n := range nodes {
fmt.Printf("%s- %s (ID: %v, Deep: %d)\n", indent, n.Name, n.ID, n.Deep)
if len(n.Children) > 0 {
PrintTree(n.Children, level+1)
}
}
}
结果
text
=== Tree (JSON) ===
- 总公司 (ID: 1, Deep: 0)
- 研发部 (ID: 2, Deep: 1)
- 前端组 (ID: 4, Deep: 2)
- 后端组 (ID: 5, Deep: 2)
- 销售部 (ID: 3, Deep: 1)
- 华东区 (ID: 6, Deep: 2)
=== Flattened List ===
[1] 总公司 (Parent: 0, Deep: 0)
[2] 研发部 (Parent: 1, Deep: 1)
[4] 前端组 (Parent: 2, Deep: 2)
[5] 后端组 (Parent: 2, Deep: 2)
[3] 销售部 (Parent: 1, Deep: 1)
[6] 华东区 (Parent: 3, Deep: 2)