Skip to content
鼓励作者:欢迎打赏犒劳

go集合工具类

list与tree互转

go
package utils

// tree_util.go

import "fmt"

// TreeNode 通用树节点结构
type TreeNode[K comparable, E any] struct {
	ID       K                 `json:"id"`        // 主键
	ParentID K                 `json:"parent_id"` // 父节点ID
	Name     string            `json:"name"`      // 名称
	Sort     int               `json:"sort"`      // 排序权重
	Deep     int               `json:"deep"`      // 层级深度(从 0 开始)
	Extra    E                 `json:"extra"`     // 额外数据(原始对象)
	Children []*TreeNode[K, E] `json:"children"`  // 子节点列表
}

// ListToTree 将扁平列表转换为树结构
// 参数:
//
//	list: 扁平的节点切片
//	rootParentID: 根节点的 ParentID(例如 0 或 "")
//
// 返回:
//
//	根节点列表(森林),error
func ListToTree[K comparable, T any](
	list []TreeNode[K, T],
	rootParentID K,
) ([]*TreeNode[K, T], error) {
	nodeMap := make(map[K]*TreeNode[K, T])
	var roots []*TreeNode[K, T]

	// 第一步:将所有节点放入 map,方便查找
	for i := range list {
		node := &list[i]
		nodeMap[node.ID] = node
		node.Children = []*TreeNode[K, T]{} // 初始化
	}

	// 第二步:建立父子关系
	for i := range list {
		node := &list[i]

		if node.ParentID == rootParentID || node.ParentID == node.ID {
			// 是根节点,或指向自己
			roots = append(roots, node)
		} else {
			parentNode, exists := nodeMap[node.ParentID]
			if !exists {
				return nil, fmt.Errorf("parent node not found for node ID=%v, ParentID=%v", node.ID, node.ParentID)
			}
			parentNode.Children = append(parentNode.Children, node)
		}

		// 设置深度(可选)
		if node.ParentID == rootParentID {
			node.Deep = 0
		} else {
			parentNode, exists := nodeMap[node.ParentID]
			if exists {
				node.Deep = parentNode.Deep + 1
			} else {
				node.Deep = 0 // 默认
			}
		}
	}

	return roots, nil
}

// TreeToList 将树结构展开为扁平列表(前序遍历)
func TreeToList[K comparable, T any](roots []*TreeNode[K, T]) []TreeNode[K, T] {
	var result []TreeNode[K, T]

	var dfs func(node *TreeNode[K, T])
	dfs = func(node *TreeNode[K, T]) {
		if node == nil {
			return
		}
		// 复制节点(避免指针问题)
		copied := *node
		copied.Children = nil // 不包含子节点
		result = append(result, copied)

		for _, child := range node.Children {
			dfs(child)
		}
	}

	for _, root := range roots {
		dfs(root)
	}

	return result
}

测试

go
// main.go
package main

import (
	"demo/utils"
	"fmt"
)

type Org struct {
	Code string `json:"code"`
	Desc string `json:"desc"`
}

func main() {
	// 示例数据:组织架构
	nodes := []utils.TreeNode[int, Org]{
		{ID: 1, ParentID: 0, Name: "总公司", Sort: 1, Extra: Org{Code: "A001", Desc: "Headquarters"}},
		{ID: 2, ParentID: 1, Name: "研发部", Sort: 10, Extra: Org{Code: "R001", Desc: "R&D"}},
		{ID: 3, ParentID: 1, Name: "销售部", Sort: 20, Extra: Org{Code: "S001", Desc: "Sales"}},
		{ID: 4, ParentID: 2, Name: "前端组", Sort: 11, Extra: Org{Code: "FE01", Desc: "Frontend"}},
		{ID: 5, ParentID: 2, Name: "后端组", Sort: 12, Extra: Org{Code: "BE01", Desc: "Backend"}},
		{ID: 6, ParentID: 3, Name: "华东区", Sort: 21, Extra: Org{Code: "EA01", Desc: "East China"}},
	}

	// 转成树
	roots, err := utils.ListToTree(nodes, 0)
	if err != nil {
		panic(err)
	}

	fmt.Println("=== Tree (JSON) ===")
	PrintTree(roots, 0)

	// 再转回 list
	flat := utils.TreeToList(roots)
	fmt.Println("\n=== Flattened List ===")
	for _, n := range flat {
		fmt.Printf("[%d] %s (Parent: %d, Deep: %d)\n", n.ID, n.Name, n.ParentID, n.Deep)
	}
}

// 简单打印树结构
func PrintTree[K comparable, T any](nodes []*utils.TreeNode[K, T], level int) {
	indent := ""
	for i := 0; i < level; i++ {
		indent += "  "
	}
	for _, n := range nodes {
		fmt.Printf("%s- %s (ID: %v, Deep: %d)\n", indent, n.Name, n.ID, n.Deep)
		if len(n.Children) > 0 {
			PrintTree(n.Children, level+1)
		}
	}
}

结果

text
=== Tree (JSON) ===
- 总公司 (ID: 1, Deep: 0)
  - 研发部 (ID: 2, Deep: 1)
    - 前端组 (ID: 4, Deep: 2)
    - 后端组 (ID: 5, Deep: 2)
  - 销售部 (ID: 3, Deep: 1)
    - 华东区 (ID: 6, Deep: 2)

=== Flattened List ===
[1] 总公司 (Parent: 0, Deep: 0)
[2] 研发部 (Parent: 1, Deep: 1)
[4] 前端组 (Parent: 2, Deep: 2)
[5] 后端组 (Parent: 2, Deep: 2)
[3] 销售部 (Parent: 1, Deep: 1)
[6] 华东区 (Parent: 3, Deep: 2)

集合判断是否相等

go
package main

import (
    "fmt"
    "reflect"
)

func main() {
    a := []string{"xx", "cc"}
    b := []string{"xx", "cc"}

    if reflect.DeepEqual(a, b) {
        fmt.Println("a 和 b 相等")
    } else {
        fmt.Println("a 和 b 不相等")
    }
}

合并两个字符串切片并去除重复元素

go
// MergeStringsUnique 合并两个字符串切片并去除重复元素
// 保持元素首次出现的顺序(先 slice1,后 slice2)
func MergeStringsUnique(slice1, slice2 []string) []string {
	seen := make(map[string]struct{}) // 使用 struct{} 避免内存开销
	var result []string

	// 处理第一个切片
	for _, item := range slice1 {
		if _, exists := seen[item]; !exists {
			seen[item] = struct{}{}
			result = append(result, item)
		}
	}

	// 处理第二个切片
	for _, item := range slice2 {
		if _, exists := seen[item]; !exists {
			seen[item] = struct{}{}
			result = append(result, item)
		}
	}

	return result
}

判断元素是否在集合中

注意:只支持 可比较的元素,str,int,bool等类型,如果struct类型的字段都是普通字段也支持。 不支持的类型有 slice、map、function

go
// Contains 判断元素 elem 是否存在于切片 slice 中
func Contains[T comparable](slice []T, elem T) bool {
    for _, v := range slice {
        if v == elem {
            return true
        }
    }
    return false
}

测试

go
package main

import "fmt"

func Contains[T comparable](slice []T, elem T) bool {
    for _, v := range slice {
        if v == elem {
            return true
        }
    }
    return false
}

func main() {
    nums := []int{1, 2, 3, 4, 5}
    fmt.Println(Contains(nums, 3)) // true
    fmt.Println(Contains(nums, 6)) // false

    words := []string{"apple", "banana", "cherry"}
    fmt.Println(Contains(words, "banana")) // true
    fmt.Println(Contains(words, "grape"))  // false
}

数组交集并集差集

工具类

shell
package set_util

// Intersection 计算两个切片的交集
// 返回同时存在于 slice1 slice2 中的元素组成的切片
func Intersection[T comparable](slice1, slice2 []T) []T {
	// 1. slice1 放入 map 以去重并建立索引
	set := make(map[T]struct{}, len(slice1))
	for _, v := range slice1 {
		set[v] = struct{}{}
	}

	// 2. 遍历 slice2,查找共同元素
	var result []T
	// 使用一个临时 map 防止结果中出现重复元素(如果 slice2 本身有重复)
	exists := make(map[T]struct{})

	for _, v := range slice2 {
		// 如果在 slice1 中存在,且在结果中尚未添加
		if _, ok := set[v]; ok {
			if _, alreadyAdded := exists[v]; !alreadyAdded {
				result = append(result, v)
				exists[v] = struct{}{}
			}
		}
	}

	return result
}

// Union 计算两个切片的并集
// 返回合并后去重的元素切片
func Union[T comparable](slice1, slice2 []T) []T {
	set := make(map[T]struct{}, len(slice1)+len(slice2))

	for _, v := range slice1 {
		set[v] = struct{}{}
	}
	for _, v := range slice2 {
		set[v] = struct{}{}
	}

	result := make([]T, 0, len(set))
	for v := range set {
		result = append(result, v)
	}
	return result
}

// Difference 计算差集 (slice1 - slice2)
// 返回存在于 slice1 但不存在于 slice2 的元素
func Difference[T comparable](slice1, slice2 []T) []T {
	set2 := make(map[T]struct{}, len(slice2))
	for _, v := range slice2 {
		set2[v] = struct{}{}
	}

	var result []T
	for _, v := range slice1 {
		if _, ok := set2[v]; !ok {
			result = append(result, v)
		}
	}
	return result
}

测试

shell
package main

import (
	"demo/utils/set_util"
	"fmt"
)

func main() {
	// 1. 测试字符串数组
	str1 := []string{"a", "b", "c", "d"}
	str2 := []string{"b", "d", "e", "f"}

	fmt.Println("字符串交集:", set_util.Intersection(str1, str2))
	// 输出: [b d]

	// 2. 测试整数数组
	int1 := []int{1, 2, 3, 4, 5}
	int2 := []int{4, 5, 6, 7}

	fmt.Println("整数交集:", set_util.Intersection(int1, int2))
	// 输出: [4 5]

	// 3. 测试并集
	fmt.Println("并集:", set_util.Union(int1, int2))
	// 输出: [1 2 3 4 5 6 7] (顺序可能不同)

	// 4. 测试差集 (int1 中有但 int2 中没有的)
	fmt.Println("差集:", set_util.Difference(int1, int2))
	// 输出: [1 2 3]
}

结果

text
字符串交集: [b d]
整数交集: [4 5]
并集: [5 2 4 6 7 1 3]
差集: [1 2 3]

如有转载或 CV 的请标注本站原文地址