
hertz中间件
中间件是 Hertz 中非常重要的概念,允许你在请求处理前后执行代码
日志中间件
该日志中间件完全可以适用于生产环境,可以打印接口的耗时,请求入参,返回值,报错信息,等
测试
go
package main
import (
"context"
"demo/middleware"
"io"
"log"
"os"
"time"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
)
func main() {
h := server.Default()
//全局使用中间件
h.Use(middleware.LoggerMiddleware())
// 也可以为特定路由组使用
// group := h.Group("/api")
// group.Use(middleware.LoggerMiddleware())
// 示例业务路由
h.POST("/test", func(c context.Context, ctx *app.RequestContext) {
time.Sleep(1 * time.Second)
// 打印入参
log.Printf("[test] request body: %s", string(ctx.Request.Body()))
ctx.String(200, "ok")
})
h.POST("/upload", func(c context.Context, ctx *app.RequestContext) {
time.Sleep(1 * time.Second)
// 处理文件上传
fileHeader, err := ctx.FormFile("file")
if err != nil {
log.Printf("[upload] failed to get file: %v", err)
ctx.String(400, "failed to get file")
return
}
// 打开上传的文件
srcFile, err := fileHeader.Open()
if err != nil {
log.Printf("[upload] failed to open file: %v", err)
ctx.String(500, "failed to open file")
return
}
defer srcFile.Close()
// 创建目标文件
dstFile, err := os.Create(fileHeader.Filename)
if err != nil {
log.Printf("[upload] failed to create file: %v", err)
ctx.String(500, "failed to create file")
return
}
defer dstFile.Close()
// 复制文件内容
if _, err := io.Copy(dstFile, srcFile); err != nil {
log.Printf("[upload] failed to save file: %v", err)
ctx.String(500, "failed to save file")
return
}
// 打印文件名和大小
log.Printf("[upload] file uploaded - name: %s, size: %d bytes", fileHeader.Filename, fileHeader.Size)
ctx.String(200, "upload success")
})
h.Spin()
}LogRequestInterceptor.go
go
package middleware
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"time"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol/consts"
)
const CurrentUser = "current_user"
// LoggerMiddleware 打印请求入参、出参、耗时、当前登录人等信息的中间件
func LoggerMiddleware() app.HandlerFunc {
return func(c context.Context, ctx *app.RequestContext) {
// 黑名单检查:SSE 和 WebSocket 请求不打印日志
contentType := string(ctx.Request.Header.ContentType())
accept := string(ctx.Request.Header.Peek("Accept"))
connection := string(ctx.Request.Header.Peek("Connection"))
upgrade := string(ctx.Request.Header.Peek("Upgrade"))
if isSSERequest(contentType, accept) || isWebSocketRequest(connection, upgrade) {
ctx.Next(c)
return
}
start := time.Now()
// 1. 获取基础信息
path := string(ctx.Request.URI().Path())
method := string(ctx.Method())
clientIP := ctx.ClientIP()
query := string(ctx.Request.URI().QueryString())
// 从上下文中获取当前登录人(假设在前置认证中间件中已存入 user 信息)
loginUser, loginUserExists := ctx.Get(CurrentUser)
// 2. 获取请求参数(根据 Content-Type 区分)
var reqParams interface{}
// 对于 multipart/form-data(文件上传),只提取文件名和大小,不读取 body
if isMultipartFormData(contentType) {
reqParams = parseMultipartForm(ctx)
} else {
// 其他类型(JSON、表单等)读取 body 并重置
bodyBytes := ctx.Request.Body()
// 重置 body,以便后续 handler 能够再次读取
ctx.Request.SetBody(bodyBytes)
switch {
case isJSON(contentType):
reqParams = string(bodyBytes) // 可改为解析成 map,但直接打字符串更通用
case isForm(contentType):
// x-www-form-urlencoded 或普通表单
ctx.Request.PostArgs() // Hertz 中使用 PostArgs 解析表单数据
args := ctx.Request.PostArgs()
formData := make(map[string][]string)
args.VisitAll(func(key, value []byte) {
k := string(key)
formData[k] = append(formData[k], string(value))
})
reqParams = formData
default:
reqParams = string(bodyBytes)
}
}
// 3. 处理业务逻辑,捕获 panic 和错误
var panicErr interface{}
defer func() {
if r := recover(); r != nil {
panicErr = r
ctx.AbortWithStatus(consts.StatusInternalServerError)
}
duration := time.Since(start)
// 组装日志字段
statusCode := ctx.Response.StatusCode()
respBody := ctx.Response.Body()
respBodySize := len(respBody)
maxRespSize := 1024 * 10 // 10KB,超过则只打印大小
respContentType := string(ctx.Response.Header.ContentType())
logFields := map[string]interface{}{
"method": method,
"path": path,
"client_ip": clientIP,
"query": query,
"req_params": reqParams,
"status_code": statusCode,
"resp_size": respBodySize,
"duration_ms": duration.Milliseconds(),
}
// 判断是否为二进制文件流
isBinary := isBinaryContentType(respContentType)
if isBinary {
// 二进制文件流只打印大小
logFields["resp_body"] = "binary data, content omitted"
} else if respBodySize > maxRespSize {
// 响应太大,只打印大小
logFields["resp_body"] = fmt.Sprintf("response too large, size: %d bytes", respBodySize)
} else {
// 正常响应,打印内容
logFields["resp_body"] = truncateString(string(respBody), maxRespSize)
}
if loginUserExists {
logFields["login_user"] = loginUser
}
// 错误信息(业务错误或 panic)
var errMsg interface{}
if panicErr != nil {
errMsg = panicErr
} else if len(ctx.Errors) > 0 {
errMsg = ctx.Errors.String()
}
if errMsg != nil {
logFields["error"] = errMsg
if jsonData, err := json.Marshal(logFields); err == nil {
log.Printf("request failed: %s", string(jsonData))
} else {
log.Printf("request failed: %v", logFields)
}
} else {
if jsonData, err := json.Marshal(logFields); err == nil {
log.Printf("request completed: %s", string(jsonData))
} else {
log.Printf("request completed: %v", logFields)
}
}
}()
// 执行下一个 handler
ctx.Next(c)
// 如果 handler 中显式设置了错误,Errors 会被填充,上面的 defer 会捕获
}
}
// 判断是否为 multipart/form-data
func isMultipartFormData(contentType string) bool {
return len(contentType) >= 19 && contentType[:19] == "multipart/form-data"
}
func isJSON(contentType string) bool {
return len(contentType) >= 16 && contentType[:16] == "application/json"
}
func isForm(contentType string) bool {
return len(contentType) >= 33 && contentType[:33] == "application/x-www-form-urlencoded"
}
// 判断是否为二进制内容类型
func isBinaryContentType(contentType string) bool {
binaryTypes := []string{
"application/octet-stream",
"application/pdf",
"application/zip",
"application/x-rar-compressed",
"application/x-7z-compressed",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"image/",
"audio/",
"video/",
}
for _, t := range binaryTypes {
if strings.HasPrefix(strings.ToLower(contentType), t) {
return true
}
}
return false
}
// 判断是否为 SSE 请求
func isSSERequest(contentType, accept string) bool {
lowerContentType := strings.ToLower(contentType)
lowerAccept := strings.ToLower(accept)
// 检查 Content-Type 是否为 text/event-stream
if strings.Contains(lowerContentType, "text/event-stream") {
return true
}
// 检查 Accept 头部是否包含 text/event-stream
if strings.Contains(lowerAccept, "text/event-stream") {
return true
}
return false
}
// 判断是否为 WebSocket 请求
func isWebSocketRequest(connection, upgrade string) bool {
lowerConnection := strings.ToLower(connection)
lowerUpgrade := strings.ToLower(upgrade)
// WebSocket 请求需要 Connection: Upgrade 和 Upgrade: websocket
hasUpgradeConnection := strings.Contains(lowerConnection, "upgrade")
hasWebSocketUpgrade := strings.Contains(lowerUpgrade, "websocket")
return hasUpgradeConnection && hasWebSocketUpgrade
}
// 解析 multipart/form-data,提取文件名和大小
func parseMultipartForm(ctx *app.RequestContext) interface{} {
// 调用 ParseMultipartForm 解析(最大内存 32MB,可根据需要调整)
// Hertz 中不需要显式调用 ParseForm,直接访问 MultipartForm 即可
form, err := ctx.Request.MultipartForm()
if err != nil {
return map[string]interface{}{
"error": "failed to parse multipart form: " + err.Error(),
}
}
if form == nil {
return "no file"
}
filesInfo := make([]map[string]interface{}, 0, len(form.File))
for fieldName, headers := range form.File {
for _, fh := range headers {
filesInfo = append(filesInfo, map[string]interface{}{
"field": fieldName,
"name": fh.Filename,
"size": fh.Size,
})
}
}
// 普通表单字段也可以顺便打印(可选)
values := make(map[string][]string)
for k, v := range form.Value {
values[k] = v
}
return map[string]interface{}{
"files": filesInfo,
"values": values,
}
}
// 截断过长字符串
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...(truncated)"
}认证中间件
go
package main
import (
"context"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/protocol/consts"
)
// 简单的认证中间件
func AuthMiddleware() app.HandlerFunc {
return func(c context.Context, ctx *app.RequestContext) {
token := string(ctx.GetHeader("Authorization"))
if token != "Bearer secret-token" {
ctx.JSON(consts.StatusUnauthorized, map[string]interface{}{
"error": "Invalid or missing token",
})
// 中止请求处理
ctx.Abort()
return
}
// 继续处理请求
ctx.Next(c)
}
}
func main() {
h := server.New()
// 在特定路由上使用中间件
authorized := h.Group("/api", AuthMiddleware())
{
authorized.GET("/data", func(c context.Context, ctx *app.RequestContext) {
ctx.JSON(consts.StatusOK, map[string]interface{}{
"data": "认证后的数据",
})
})
}
// 不需要认证的路由
h.GET("/public", func(c context.Context, ctx *app.RequestContext) {
ctx.JSON(consts.StatusOK, map[string]interface{}{
"data": "公共数据",
})
})
h.Spin()
}打印入参中间件
go
package utils
import (
"encoding/json"
"fmt"
"net/url"
"strings"
"github.com/cloudwego/hertz/pkg/app"
)
// RequestDump 用于结构化输出所有请求参数
type RequestDump struct {
Query map[string][]string `json:"query,omitempty"`
Path map[string]string `json:"path,omitempty"`
Form map[string][]string `json:"form,omitempty"`
JSON map[string]interface{} `json:"json,omitempty"`
Header map[string][]string `json:"header,omitempty"` // 可选
}
// DumpRequestAll 打印 Hertz 请求中的所有参数(调试专用)
func DumpRequestAll(c *app.RequestContext) RequestDump {
var dump RequestDump
// 初始化字段
dump.Query = make(map[string][]string)
dump.Path = make(map[string]string)
dump.Form = make(map[string][]string)
dump.Header = make(map[string][]string)
fmt.Println("=== HERTZ 请求参数完整打印 ===")
contentType := string(c.Request.Header.ContentType())
// 1. ✅ Query 参数(GET 参数)
c.Request.URI().QueryArgs().VisitAll(func(key, value []byte) {
k := string(key)
dump.Query[k] = append(dump.Query[k], string(value))
})
if len(dump.Query) > 0 {
printMapOfSlice(contentType, "Query Params", dump.Query)
}
// 2. ✅ Path 参数(如 /user/:userId)
for _, p := range c.Params {
dump.Path[p.Key] = p.Value
}
if len(dump.Path) > 0 {
printMapOfSlice(contentType, "Path Params", dump.Path)
}
// 3. ✅ Header(可选打印)
c.Request.Header.VisitAll(func(key, value []byte) {
k := string(key)
dump.Header[k] = append(dump.Header[k], string(value))
})
// 注释掉下面这行如果你不想频繁打印 header
// if len(dump.Header) > 0 {
// fmt.Println("Headers:")
// printMapOfSlice(dump.Header, " ")
// }
// 4. ✅ Body: JSON 或 Form
body := c.Request.Body()
if len(body) > 0 {
// 解析
if len(body) > 1024*1024 { // 超过 1MB
fmt.Println("Body too large, skip parsing")
} else {
// 如果是 JSON
if strings.Contains(contentType, "application/json") {
var jsonData map[string]interface{}
if err := json.Unmarshal(body, &jsonData); err == nil {
dump.JSON = jsonData
printMapOfSlice(contentType, "json", dump.JSON)
} else {
fmt.Printf("JSON 解析失败: %v\n", err)
}
} else if strings.Contains(contentType, "x-www-form-urlencoded") { // 如果是 Form
// ✅ 手动解析 form body
formData, err := url.ParseQuery(string(body))
if err == nil {
// 转成 map[string][]string
for k, v := range formData {
dump.Form[k] = v
}
if len(dump.Form) > 0 {
printMapOfSlice(contentType, "form", dump.Form)
}
} else {
fmt.Printf("Form 解析失败: %v\n", err)
}
} else if strings.Contains(contentType, "multipart/form-data") {
// 解析 multipart 表单
form, err := c.Request.MultipartForm()
if err != nil {
fmt.Printf("Failed to parse multipart form: %v\n", err)
} else {
// 遍历所有文本字段(非文件)
for key, values := range form.Value {
// values 是 []string,因为同一个 key 可能出现多次
dump.Form[key] = values
}
printMapOfSlice(contentType, "multipart-form", dump.Form)
}
}
}
}
fmt.Println("===参数打印结束 ===")
return dump
}
// 工具函数:格式化打印 map[string][]string
func printMapOfSlice(contentType string, source string, m any) {
marshal, _ := json.Marshal(m)
fmt.Printf("contentType : %s , source = %s , request param : %s\n", contentType, source, marshal)
}打印接口请求响应耗时中间件
终极版
main 注册
go
package main
import (
"context"
"demo/middleware"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/protocol/consts"
)
func main() {
// 创建一个默认的 Hertz 服务器实例(默认使用 8888 端口)
// 使用 server.Default() 会默认添加一些中间件,例如恢复中间件
// 使用 server.New() 则可以创建一个没有任何默认中间件的纯净实例
h := server.Default()
h.Use(middleware.LogRequestResponse())
// 注册一个 GET 路由处理函数
// 第一个参数是路由路径,第二个参数是处理函数
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
// 返回 JSON 响应
ctx.JSON(consts.StatusOK, map[string]interface{}{
"message": "pong",
"status": "success",
})
})
// 启动服务器
// Spin 方法会阻塞当前 goroutine,直到服务器被关闭
h.Spin()
}中间件
go
package middleware
import (
"context"
"encoding/json"
"fmt"
"log"
"net/url"
"strings"
"time"
"github.com/cloudwego/hertz/pkg/app"
)
// LogRequestResponse 是一个 Hertz 中间件,用于打印请求/响应和耗时
func LogRequestResponse() app.HandlerFunc {
return func(c context.Context, ctx *app.RequestContext) {
start := time.Now()
// 记录入参
reqLog := DumpRequestAll(ctx)
// 执行业务逻辑
ctx.Next(context.Background())
// === 处理响应体 ===
respBody := ctx.Response.Body()
contentType := string(ctx.Response.Header.ContentType())
var respLog string
if len(respBody) == 0 {
respLog = "{}"
} else if isTextContent(contentType) {
respLog = string(respBody) // 安全:这是 JSON / text
} else {
respLog = fmt.Sprintf("<binary> size=%d type=%s", len(respBody), contentType)
}
cost := time.Since(start)
m := map[string]interface{}{
"log_tye": "[API LOG]",
"method": string(ctx.Method()),
"path": string(ctx.Path()),
"request": reqLog,
"response": respLog,
"cost": cost,
}
jsonData, _ := json.Marshal(m)
log.Printf("api 拦截log : %v", string(jsonData))
}
}
func isTextContent(ct string) bool {
return strings.Contains(ct, "application/json") ||
strings.Contains(ct, "text/") ||
ct == "" // 默认当作文本
}
// RequestDump 用于结构化输出所有请求参数
type RequestDump struct {
Query map[string][]string `json:"query,omitempty"`
Path map[string]string `json:"path,omitempty"`
Form map[string][]string `json:"form,omitempty"`
JSON map[string]interface{} `json:"json,omitempty"`
Header map[string][]string `json:"header,omitempty"` // 可选
}
// DumpRequestAll 打印 Hertz 请求中的所有参数(调试专用)
func DumpRequestAll(c *app.RequestContext) RequestDump {
var dump RequestDump
// 初始化字段
dump.Query = make(map[string][]string)
dump.Path = make(map[string]string)
dump.Form = make(map[string][]string)
dump.Header = make(map[string][]string)
contentType := string(c.Request.Header.ContentType())
// 1. ✅ Query 参数(GET 参数)
c.Request.URI().QueryArgs().VisitAll(func(key, value []byte) {
k := string(key)
dump.Query[k] = append(dump.Query[k], string(value))
})
if len(dump.Query) > 0 {
printMapOfSlice(contentType, "Query Params", dump.Query)
}
// 2. ✅ Path 参数(如 /user/:userId)
for _, p := range c.Params {
dump.Path[p.Key] = p.Value
}
if len(dump.Path) > 0 {
printMapOfSlice(contentType, "Path Params", dump.Path)
}
// 3. ✅ Header(可选打印)
c.Request.Header.VisitAll(func(key, value []byte) {
k := string(key)
dump.Header[k] = append(dump.Header[k], string(value))
})
// 注释掉下面这行如果你不想频繁打印 header
// if len(dump.Header) > 0 {
// fmt.Println("Headers:")
// printMapOfSlice(dump.Header, " ")
// }
// 4. ✅ Body: JSON 或 Form
body := c.Request.Body()
if len(body) > 0 {
// 解析
if len(body) > 1024*1024 { // 超过 1MB
fmt.Println("Body too large, skip parsing")
} else {
// 如果是 JSON
if strings.Contains(contentType, "application/json") {
var jsonData map[string]interface{}
if err := json.Unmarshal(body, &jsonData); err == nil {
dump.JSON = jsonData
printMapOfSlice(contentType, "json", dump.JSON)
} else {
fmt.Printf("JSON 解析失败: %v\n", err)
}
} else if strings.Contains(contentType, "x-www-form-urlencoded") { // 如果是 Form
// ✅ 手动解析 form body
formData, err := url.ParseQuery(string(body))
if err == nil {
// 转成 map[string][]string
for k, v := range formData {
dump.Form[k] = v
}
if len(dump.Form) > 0 {
printMapOfSlice(contentType, "form", dump.Form)
}
} else {
fmt.Printf("Form 解析失败: %v\n", err)
}
} else if strings.Contains(contentType, "multipart/form-data") {
// 解析 multipart 表单
form, err := c.Request.MultipartForm()
if err != nil {
fmt.Printf("Failed to parse multipart form: %v\n", err)
} else {
// 遍历所有文本字段(非文件)
for key, values := range form.Value {
// values 是 []string,因为同一个 key 可能出现多次
dump.Form[key] = values
}
printMapOfSlice(contentType, "multipart-form", dump.Form)
}
}
}
}
return dump
}
// 工具函数:格式化打印 map[string][]string
func printMapOfSlice(contentType string, source string, m any) {
marshal, _ := json.Marshal(m)
fmt.Printf("contentType : %s , source = %s , request param : %s\n", contentType, source, marshal)
}
