Skip to content
鼓励作者:欢迎打赏犒劳

hertz中间件

中间件是 Hertz 中非常重要的概念,允许你在请求处理前后执行代码

日志中间件

该日志中间件完全可以适用于生产环境,可以打印接口的耗时,请求入参,返回值,报错信息,等

测试

go
package main

import (
	"context"
	"demo/middleware"
	"io"
	"log"
	"os"
	"time"

	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/app/server"
)

func main() {
	h := server.Default()
	//全局使用中间件
	h.Use(middleware.LoggerMiddleware())

	// 也可以为特定路由组使用
	// group := h.Group("/api")
	// group.Use(middleware.LoggerMiddleware())

	// 示例业务路由
	h.POST("/test", func(c context.Context, ctx *app.RequestContext) {
		time.Sleep(1 * time.Second)
		// 打印入参
		log.Printf("[test] request body: %s", string(ctx.Request.Body()))
		ctx.String(200, "ok")
	})
	h.POST("/upload", func(c context.Context, ctx *app.RequestContext) {
		time.Sleep(1 * time.Second)
		// 处理文件上传
		fileHeader, err := ctx.FormFile("file")
		if err != nil {
			log.Printf("[upload] failed to get file: %v", err)
			ctx.String(400, "failed to get file")
			return
		}

		// 打开上传的文件
		srcFile, err := fileHeader.Open()
		if err != nil {
			log.Printf("[upload] failed to open file: %v", err)
			ctx.String(500, "failed to open file")
			return
		}
		defer srcFile.Close()

		// 创建目标文件
		dstFile, err := os.Create(fileHeader.Filename)
		if err != nil {
			log.Printf("[upload] failed to create file: %v", err)
			ctx.String(500, "failed to create file")
			return
		}
		defer dstFile.Close()

		// 复制文件内容
		if _, err := io.Copy(dstFile, srcFile); err != nil {
			log.Printf("[upload] failed to save file: %v", err)
			ctx.String(500, "failed to save file")
			return
		}

		// 打印文件名和大小
		log.Printf("[upload] file uploaded - name: %s, size: %d bytes", fileHeader.Filename, fileHeader.Size)
		ctx.String(200, "upload success")
	})
	h.Spin()
}

LogRequestInterceptor.go

go
package middleware

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
)

const CurrentUser = "current_user"

// LoggerMiddleware 打印请求入参、出参、耗时、当前登录人等信息的中间件
func LoggerMiddleware() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		// 黑名单检查:SSE 和 WebSocket 请求不打印日志
		contentType := string(ctx.Request.Header.ContentType())
		accept := string(ctx.Request.Header.Peek("Accept"))
		connection := string(ctx.Request.Header.Peek("Connection"))
		upgrade := string(ctx.Request.Header.Peek("Upgrade"))
		if isSSERequest(contentType, accept) || isWebSocketRequest(connection, upgrade) {
			ctx.Next(c)
			return
		}

		start := time.Now()

		// 1. 获取基础信息
		path := string(ctx.Request.URI().Path())
		method := string(ctx.Method())
		clientIP := ctx.ClientIP()
		query := string(ctx.Request.URI().QueryString())
		// 从上下文中获取当前登录人(假设在前置认证中间件中已存入 user 信息)
		loginUser, loginUserExists := ctx.Get(CurrentUser)

		// 2. 获取请求参数(根据 Content-Type 区分)
		var reqParams interface{}

		// 对于 multipart/form-data(文件上传),只提取文件名和大小,不读取 body
		if isMultipartFormData(contentType) {
			reqParams = parseMultipartForm(ctx)
		} else {
			// 其他类型(JSON、表单等)读取 body 并重置
			bodyBytes := ctx.Request.Body()
			// 重置 body,以便后续 handler 能够再次读取
			ctx.Request.SetBody(bodyBytes)

			switch {
			case isJSON(contentType):
				reqParams = string(bodyBytes) // 可改为解析成 map,但直接打字符串更通用
			case isForm(contentType):
				// x-www-form-urlencoded 或普通表单
				ctx.Request.PostArgs() // Hertz 中使用 PostArgs 解析表单数据
				args := ctx.Request.PostArgs()
				formData := make(map[string][]string)
				args.VisitAll(func(key, value []byte) {
					k := string(key)
					formData[k] = append(formData[k], string(value))
				})
				reqParams = formData
			default:
				reqParams = string(bodyBytes)
			}
		}

		// 3. 处理业务逻辑,捕获 panic 和错误
		var panicErr interface{}
		defer func() {
			if r := recover(); r != nil {
				panicErr = r
				ctx.AbortWithStatus(consts.StatusInternalServerError)
			}
			duration := time.Since(start)

			// 组装日志字段
			statusCode := ctx.Response.StatusCode()
			respBody := ctx.Response.Body()
			respBodySize := len(respBody)
			maxRespSize := 1024 * 10 // 10KB,超过则只打印大小
			respContentType := string(ctx.Response.Header.ContentType())

			logFields := map[string]interface{}{
				"method":      method,
				"path":        path,
				"client_ip":   clientIP,
				"query":       query,
				"req_params":  reqParams,
				"status_code": statusCode,
				"resp_size":   respBodySize,
				"duration_ms": duration.Milliseconds(),
			}

			// 判断是否为二进制文件流
			isBinary := isBinaryContentType(respContentType)

			if isBinary {
				// 二进制文件流只打印大小
				logFields["resp_body"] = "binary data, content omitted"
			} else if respBodySize > maxRespSize {
				// 响应太大,只打印大小
				logFields["resp_body"] = fmt.Sprintf("response too large, size: %d bytes", respBodySize)
			} else {
				// 正常响应,打印内容
				logFields["resp_body"] = truncateString(string(respBody), maxRespSize)
			}
			if loginUserExists {
				logFields["login_user"] = loginUser
			}

			// 错误信息(业务错误或 panic)
			var errMsg interface{}
			if panicErr != nil {
				errMsg = panicErr
			} else if len(ctx.Errors) > 0 {
				errMsg = ctx.Errors.String()
			}
			if errMsg != nil {
				logFields["error"] = errMsg
				if jsonData, err := json.Marshal(logFields); err == nil {
					log.Printf("request failed: %s", string(jsonData))
				} else {
					log.Printf("request failed: %v", logFields)
				}
			} else {
				if jsonData, err := json.Marshal(logFields); err == nil {
					log.Printf("request completed: %s", string(jsonData))
				} else {
					log.Printf("request completed: %v", logFields)
				}
			}
		}()

		// 执行下一个 handler
		ctx.Next(c)

		// 如果 handler 中显式设置了错误,Errors 会被填充,上面的 defer 会捕获
	}
}

// 判断是否为 multipart/form-data
func isMultipartFormData(contentType string) bool {
	return len(contentType) >= 19 && contentType[:19] == "multipart/form-data"
}

func isJSON(contentType string) bool {
	return len(contentType) >= 16 && contentType[:16] == "application/json"
}

func isForm(contentType string) bool {
	return len(contentType) >= 33 && contentType[:33] == "application/x-www-form-urlencoded"
}

// 判断是否为二进制内容类型
func isBinaryContentType(contentType string) bool {
	binaryTypes := []string{
		"application/octet-stream",
		"application/pdf",
		"application/zip",
		"application/x-rar-compressed",
		"application/x-7z-compressed",
		"application/vnd.ms-excel",
		"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
		"application/msword",
		"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
		"application/vnd.ms-powerpoint",
		"application/vnd.openxmlformats-officedocument.presentationml.presentation",
		"image/",
		"audio/",
		"video/",
	}

	for _, t := range binaryTypes {
		if strings.HasPrefix(strings.ToLower(contentType), t) {
			return true
		}
	}
	return false
}

// 判断是否为 SSE 请求
func isSSERequest(contentType, accept string) bool {
	lowerContentType := strings.ToLower(contentType)
	lowerAccept := strings.ToLower(accept)

	// 检查 Content-Type 是否为 text/event-stream
	if strings.Contains(lowerContentType, "text/event-stream") {
		return true
	}

	// 检查 Accept 头部是否包含 text/event-stream
	if strings.Contains(lowerAccept, "text/event-stream") {
		return true
	}

	return false
}

// 判断是否为 WebSocket 请求
func isWebSocketRequest(connection, upgrade string) bool {
	lowerConnection := strings.ToLower(connection)
	lowerUpgrade := strings.ToLower(upgrade)

	// WebSocket 请求需要 Connection: Upgrade 和 Upgrade: websocket
	hasUpgradeConnection := strings.Contains(lowerConnection, "upgrade")
	hasWebSocketUpgrade := strings.Contains(lowerUpgrade, "websocket")

	return hasUpgradeConnection && hasWebSocketUpgrade
}

// 解析 multipart/form-data,提取文件名和大小
func parseMultipartForm(ctx *app.RequestContext) interface{} {
	// 调用 ParseMultipartForm 解析(最大内存 32MB,可根据需要调整)
	// Hertz 中不需要显式调用 ParseForm,直接访问 MultipartForm 即可
	form, err := ctx.Request.MultipartForm()
	if err != nil {
		return map[string]interface{}{
			"error": "failed to parse multipart form: " + err.Error(),
		}
	}
	if form == nil {
		return "no file"
	}
	filesInfo := make([]map[string]interface{}, 0, len(form.File))
	for fieldName, headers := range form.File {
		for _, fh := range headers {
			filesInfo = append(filesInfo, map[string]interface{}{
				"field": fieldName,
				"name":  fh.Filename,
				"size":  fh.Size,
			})
		}
	}
	// 普通表单字段也可以顺便打印(可选)
	values := make(map[string][]string)
	for k, v := range form.Value {
		values[k] = v
	}
	return map[string]interface{}{
		"files":  filesInfo,
		"values": values,
	}
}

// 截断过长字符串
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	return s[:maxLen] + "...(truncated)"
}

认证中间件

go
package main

import (
	"context"

	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/app/server"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
)

// 简单的认证中间件
func AuthMiddleware() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		token := string(ctx.GetHeader("Authorization"))

		if token != "Bearer secret-token" {
			ctx.JSON(consts.StatusUnauthorized, map[string]interface{}{
				"error": "Invalid or missing token",
			})
			// 中止请求处理
			ctx.Abort()
			return
		}

		// 继续处理请求
		ctx.Next(c)
	}
}

func main() {
	h := server.New()

	// 在特定路由上使用中间件
	authorized := h.Group("/api", AuthMiddleware())
	{
		authorized.GET("/data", func(c context.Context, ctx *app.RequestContext) {
			ctx.JSON(consts.StatusOK, map[string]interface{}{
				"data": "认证后的数据",
			})
		})
	}

	// 不需要认证的路由
	h.GET("/public", func(c context.Context, ctx *app.RequestContext) {
		ctx.JSON(consts.StatusOK, map[string]interface{}{
			"data": "公共数据",
		})
	})

	h.Spin()
}

打印入参中间件

go
package utils

import (
	"encoding/json"
	"fmt"
	"net/url"
	"strings"

	"github.com/cloudwego/hertz/pkg/app"
)

// RequestDump 用于结构化输出所有请求参数
type RequestDump struct {
	Query  map[string][]string    `json:"query,omitempty"`
	Path   map[string]string      `json:"path,omitempty"`
	Form   map[string][]string    `json:"form,omitempty"`
	JSON   map[string]interface{} `json:"json,omitempty"`
	Header map[string][]string    `json:"header,omitempty"` // 可选
}

// DumpRequestAll 打印 Hertz 请求中的所有参数(调试专用)
func DumpRequestAll(c *app.RequestContext) RequestDump {
	var dump RequestDump

	// 初始化字段
	dump.Query = make(map[string][]string)
	dump.Path = make(map[string]string)
	dump.Form = make(map[string][]string)
	dump.Header = make(map[string][]string)

	fmt.Println("=== HERTZ 请求参数完整打印 ===")
	contentType := string(c.Request.Header.ContentType())
	// 1. ✅ Query 参数(GET 参数)
	c.Request.URI().QueryArgs().VisitAll(func(key, value []byte) {
		k := string(key)
		dump.Query[k] = append(dump.Query[k], string(value))
	})
	if len(dump.Query) > 0 {
		printMapOfSlice(contentType, "Query Params", dump.Query)
	}

	// 2. ✅ Path 参数(如 /user/:userId)
	for _, p := range c.Params {
		dump.Path[p.Key] = p.Value
	}
	if len(dump.Path) > 0 {
		printMapOfSlice(contentType, "Path Params", dump.Path)
	}

	// 3. ✅ Header(可选打印)
	c.Request.Header.VisitAll(func(key, value []byte) {
		k := string(key)
		dump.Header[k] = append(dump.Header[k], string(value))
	})
	// 注释掉下面这行如果你不想频繁打印 header
	// if len(dump.Header) > 0 {
	// 	fmt.Println("Headers:")
	// 	printMapOfSlice(dump.Header, "  ")
	// }

	// 4. ✅ Body: JSON 或 Form
	body := c.Request.Body()

	if len(body) > 0 {
		// 解析
		if len(body) > 1024*1024 { // 超过 1MB
			fmt.Println("Body too large, skip parsing")
		} else {
			// 如果是 JSON
			if strings.Contains(contentType, "application/json") {
				var jsonData map[string]interface{}
				if err := json.Unmarshal(body, &jsonData); err == nil {
					dump.JSON = jsonData
					printMapOfSlice(contentType, "json", dump.JSON)
				} else {
					fmt.Printf("JSON 解析失败: %v\n", err)
				}
			} else if strings.Contains(contentType, "x-www-form-urlencoded") { // 如果是 Form
				// ✅ 手动解析 form body
				formData, err := url.ParseQuery(string(body))
				if err == nil {
					// 转成 map[string][]string
					for k, v := range formData {
						dump.Form[k] = v
					}
					if len(dump.Form) > 0 {
						printMapOfSlice(contentType, "form", dump.Form)
					}
				} else {
					fmt.Printf("Form 解析失败: %v\n", err)
				}
			} else if strings.Contains(contentType, "multipart/form-data") {
				// 解析 multipart 表单
				form, err := c.Request.MultipartForm()
				if err != nil {
					fmt.Printf("Failed to parse multipart form: %v\n", err)
				} else {
					// 遍历所有文本字段(非文件)
					for key, values := range form.Value {
						// values 是 []string,因为同一个 key 可能出现多次
						dump.Form[key] = values
					}
					printMapOfSlice(contentType, "multipart-form", dump.Form)
				}
			}
		}
	}

	fmt.Println("===参数打印结束 ===")
	return dump
}

// 工具函数:格式化打印 map[string][]string
func printMapOfSlice(contentType string, source string, m any) {
	marshal, _ := json.Marshal(m)
	fmt.Printf("contentType : %s , source = %s , request param :  %s\n", contentType, source, marshal)
}

打印接口请求响应耗时中间件

终极版

main 注册

go
package main

import (
	"context"
	"demo/middleware"

	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/app/server"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
)

func main() {
	// 创建一个默认的 Hertz 服务器实例(默认使用 8888 端口)
	// 使用 server.Default() 会默认添加一些中间件,例如恢复中间件
	// 使用 server.New() 则可以创建一个没有任何默认中间件的纯净实例
	h := server.Default()
	h.Use(middleware.LogRequestResponse())

	// 注册一个 GET 路由处理函数
	// 第一个参数是路由路径,第二个参数是处理函数
	h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
		// 返回 JSON 响应
		ctx.JSON(consts.StatusOK, map[string]interface{}{
			"message": "pong",
			"status":  "success",
		})
	})

	// 启动服务器
	// Spin 方法会阻塞当前 goroutine,直到服务器被关闭
	h.Spin()
}

中间件

go
package middleware

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/url"
	"strings"
	"time"

	"github.com/cloudwego/hertz/pkg/app"
)

// LogRequestResponse 是一个 Hertz 中间件,用于打印请求/响应和耗时
func LogRequestResponse() app.HandlerFunc {
	return func(c context.Context, ctx *app.RequestContext) {
		start := time.Now()

		// 记录入参
		reqLog := DumpRequestAll(ctx)

		// 执行业务逻辑
		ctx.Next(context.Background())

		// === 处理响应体 ===
		respBody := ctx.Response.Body()
		contentType := string(ctx.Response.Header.ContentType())
		var respLog string

		if len(respBody) == 0 {
			respLog = "{}"
		} else if isTextContent(contentType) {
			respLog = string(respBody) // 安全:这是 JSON / text
		} else {
			respLog = fmt.Sprintf("<binary> size=%d type=%s", len(respBody), contentType)
		}

		cost := time.Since(start)
		m := map[string]interface{}{
			"log_tye":  "[API LOG]",
			"method":   string(ctx.Method()),
			"path":     string(ctx.Path()),
			"request":  reqLog,
			"response": respLog,
			"cost":     cost,
		}
		jsonData, _ := json.Marshal(m)
		log.Printf("api 拦截log : %v", string(jsonData))
	}
}
func isTextContent(ct string) bool {
	return strings.Contains(ct, "application/json") ||
		strings.Contains(ct, "text/") ||
		ct == "" // 默认当作文本
}

// RequestDump 用于结构化输出所有请求参数
type RequestDump struct {
	Query  map[string][]string    `json:"query,omitempty"`
	Path   map[string]string      `json:"path,omitempty"`
	Form   map[string][]string    `json:"form,omitempty"`
	JSON   map[string]interface{} `json:"json,omitempty"`
	Header map[string][]string    `json:"header,omitempty"` // 可选
}

// DumpRequestAll 打印 Hertz 请求中的所有参数(调试专用)
func DumpRequestAll(c *app.RequestContext) RequestDump {
	var dump RequestDump

	// 初始化字段
	dump.Query = make(map[string][]string)
	dump.Path = make(map[string]string)
	dump.Form = make(map[string][]string)
	dump.Header = make(map[string][]string)

	contentType := string(c.Request.Header.ContentType())
	// 1. ✅ Query 参数(GET 参数)
	c.Request.URI().QueryArgs().VisitAll(func(key, value []byte) {
		k := string(key)
		dump.Query[k] = append(dump.Query[k], string(value))
	})
	if len(dump.Query) > 0 {
		printMapOfSlice(contentType, "Query Params", dump.Query)
	}

	// 2. ✅ Path 参数(如 /user/:userId)
	for _, p := range c.Params {
		dump.Path[p.Key] = p.Value
	}
	if len(dump.Path) > 0 {
		printMapOfSlice(contentType, "Path Params", dump.Path)
	}

	// 3. ✅ Header(可选打印)
	c.Request.Header.VisitAll(func(key, value []byte) {
		k := string(key)
		dump.Header[k] = append(dump.Header[k], string(value))
	})
	// 注释掉下面这行如果你不想频繁打印 header
	// if len(dump.Header) > 0 {
	// 	fmt.Println("Headers:")
	// 	printMapOfSlice(dump.Header, "  ")
	// }

	// 4. ✅ Body: JSON 或 Form
	body := c.Request.Body()

	if len(body) > 0 {
		// 解析
		if len(body) > 1024*1024 { // 超过 1MB
			fmt.Println("Body too large, skip parsing")
		} else {
			// 如果是 JSON
			if strings.Contains(contentType, "application/json") {
				var jsonData map[string]interface{}
				if err := json.Unmarshal(body, &jsonData); err == nil {
					dump.JSON = jsonData
					printMapOfSlice(contentType, "json", dump.JSON)
				} else {
					fmt.Printf("JSON 解析失败: %v\n", err)
				}
			} else if strings.Contains(contentType, "x-www-form-urlencoded") { // 如果是 Form
				// ✅ 手动解析 form body
				formData, err := url.ParseQuery(string(body))
				if err == nil {
					// 转成 map[string][]string
					for k, v := range formData {
						dump.Form[k] = v
					}
					if len(dump.Form) > 0 {
						printMapOfSlice(contentType, "form", dump.Form)
					}
				} else {
					fmt.Printf("Form 解析失败: %v\n", err)
				}
			} else if strings.Contains(contentType, "multipart/form-data") {
				// 解析 multipart 表单
				form, err := c.Request.MultipartForm()
				if err != nil {
					fmt.Printf("Failed to parse multipart form: %v\n", err)
				} else {
					// 遍历所有文本字段(非文件)
					for key, values := range form.Value {
						// values 是 []string,因为同一个 key 可能出现多次
						dump.Form[key] = values
					}
					printMapOfSlice(contentType, "multipart-form", dump.Form)
				}
			}
		}
	}

	return dump
}

// 工具函数:格式化打印 map[string][]string
func printMapOfSlice(contentType string, source string, m any) {
	marshal, _ := json.Marshal(m)
	fmt.Printf("contentType : %s , source = %s , request param :  %s\n", contentType, source, marshal)
}

如有转载或 CV 的请标注本站原文地址