Skip to content
鼓励作者:欢迎打赏犒劳

go整合SSE

go原生+sse的demo

web框架是原生的,还包括一个proxy层,稍微有点复杂,但是绝对是很实用的需求。需求场景是 前端调后端sse,后端需要再调取一个sse接口.

下面的代码完美的展示了上面的需求。

大概得过程。启动go服务,这里会启动2个服务,一个是上游的sse服务,一个是提供前端接口的sse服务。

前端调用 http://localhost:8080/chat , 然后接口调用上游,上游返回sse数据,下游将sse数据也返回给前端 完成闭环

用标准 http.ResponseWriter 替换 app.RequestContext,使代码不依赖特定的 Web 框架(如 Hertz/Gin 的特定封装),更容易移植和理解。如果你必须用 Hertz,只需将 w http.ResponseWriter 改回 c *app.RequestContext 并调用 c.Writer 即可。

后端代码

go
package main

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"
	"time"
)

// --- 配置区域 ---

// 允许的源列表 (生产环境建议配置为具体域名,不要使用 *)
// "*" 表示允许所有域名,但在携带 Cookie 时不能用 "*"
var allowedOrigins = []string{
	"*",
	// "http://localhost:3000",
	// "https://your-frontend.com",
}

// --- 核心逻辑 ---

// corsMiddleware 跨域中间件
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")

		// 1. 检查 Origin 是否合法 (如果配置了具体白名单)
		// 这里为了演示简单,如果配置了 "*" 则直接通过,实际生产建议校验白名单
		allowOrigin := ""
		for _, o := range allowedOrigins {
			if o == "*" || o == origin {
				allowOrigin = o
				break
			}
		}

		// 如果没有匹配到且不是 "*",则拒绝 (这里简化处理,默认允许)
		if allowOrigin == "" && len(allowedOrigins) > 0 && allowedOrigins[0] != "*" {
			http.Error(w, "CORS policy violation", http.StatusForbidden)
			return
		}

		// 如果配置是 "*",响应头也设 "*" (注意:如果用了 Allow-Credentials,这里不能是 "*")
		if allowOrigin == "" {
			allowOrigin = "*"
		}

		// 2. 设置通用 CORS 响应头
		w.Header().Set("Access-Control-Allow-Origin", allowOrigin)
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
		w.Header().Set("Access-Control-Max-Age", "86400") // 预检结果缓存 24 小时

		// 如果是 SSE,通常需要允许凭证 (如果前端需要传 Cookie)
		// 注意:如果 Allow-Origin 是 "*",则 Allow-Credentials 必须为 false (或省略)
		if allowOrigin != "*" {
			w.Header().Set("Access-Control-Allow-Credentials", "true")
		}

		// 3. 处理预检请求 (OPTIONS)
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}

		// 4. 继续处理正常请求
		next(w, r)
	}
}

// StreamProxy 处理上游 SSE 流并转发给下游客户端
func StreamProxy(ctx context.Context, w http.ResponseWriter, upstreamURL string, msgID int64) error {
	req, err := http.NewRequestWithContext(ctx, "GET", upstreamURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create upstream request: %w", err)
	}
	req.Header.Set("Accept", "text/event-stream")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to call upstream service: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("upstream returned status %d: %s", resp.StatusCode, string(body))
	}

	// 设置 SSE 响应头 (CORS 头已经在中间件设置过了)
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")

	flusher, ok := w.(http.Flusher)
	if !ok {
		return fmt.Errorf("streaming unsupported")
	}

	var collectedData []string
	reader := bufio.NewReader(resp.Body)

	for {
		select {
		case <-ctx.Done():
			log.Printf("Client disconnected, stopping stream for msgID=%d", msgID)
			return ctx.Err()
		default:
		}

		line, err := reader.ReadBytes('\n')
		if len(line) == 0 && err != nil {
			if err == io.EOF {
				break
			}
			return fmt.Errorf("read error: %w", err)
		}

		lineStr := strings.TrimSpace(string(line))
		if lineStr == "" {
			continue
		}
		// ================= [修复点开始] =================
		// SSE 格式通常是 "data: <payload>"
		// 我们需要去掉 "data: " 前缀才能解析后面的 JSON
		jsonStr := lineStr
		prefix := "data: "

		if strings.HasPrefix(lineStr, prefix) {
			jsonStr = strings.TrimPrefix(lineStr, prefix)
		} else if strings.HasPrefix(lineStr, "data:") {
			// 兼容没有空格的情况 "data:<payload>"
			jsonStr = strings.TrimPrefix(lineStr, "data:")
		}
		// 如果去掉前缀后还是空的,跳过
		if jsonStr == "" {
			continue
		}
		// ================= [修复点结束] =================
		var upMsg UpstreamMsg
		if err := json.Unmarshal([]byte(jsonStr), &upMsg); err != nil {
			log.Printf("Error unmarshalling JSON: %v, data: %s", err, lineStr)
			continue
		}

		if upMsg.Type == "task_completed" {
			log.Printf("Task completed for msgID=%d", msgID)
			break
		}

		downMsg := DownstreamMsg{
			Type: upMsg.Type,
			Data: upMsg.Data,
			ID:   msgID,
		}

		collectedData = append(collectedData, upMsg.Data)

		outBytes, err := json.Marshal(downMsg)
		if err != nil {
			log.Printf("Error marshalling response: %v", err)
			continue
		}

		// 发送 SSE 事件
		_, err = fmt.Fprintf(w, "data: %s\n\n", outBytes)
		if err != nil {
			return fmt.Errorf("failed to write to client: %w", err)
		}
		flusher.Flush()

		if upMsg.Type == "error" {
			log.Printf("Upstream reported error: %s", upMsg.Data)
			return fmt.Errorf("upstream_error: %s", upMsg.Data)
		}
	}

	fullDataJSON, _ := json.Marshal(collectedData)
	log.Printf("Stream finished. Full data collected: %s", string(fullDataJSON))
	return nil
}

// HandleChatStream HTTP 入口
func HandleChatStream(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	msgID := time.Now().UnixNano()
	upstreamURL := "http://localhost:8081/mock-sse"

	log.Printf("Starting stream proxy for msgID=%d from origin: %s", msgID, r.Header.Get("Origin"))

	err := StreamProxy(ctx, w, upstreamURL, msgID)
	if err != nil {
		if strings.Contains(err.Error(), "upstream_error") {
			errMsg := DownstreamMsg{Type: "error", Data: err.Error()[14:], ID: msgID}
			b, _ := json.Marshal(errMsg)
			fmt.Fprintf(w, "data: %s\n\n", b)
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		} else if err != context.Canceled {
			log.Printf("Stream proxy failed: %v", err)
		}
	}
}

// --- 数据结构 ---

type UpstreamMsg struct {
	Type string `json:"type"`
	Data string `json:"data"`
	ID   int64  `json:"id,omitempty"`
}

type DownstreamMsg struct {
	Type string `json:"type"`
	Data string `json:"data"`
	ID   int64  `json:"id"`
}

// --- Mock 上游服务 ---

func MockUpstreamSSE(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	flusher, _ := w.(http.Flusher)

	steps := []UpstreamMsg{
		{Type: "thinking", Data: "正在思考..."},
		{Type: "content", Data: "Hello"},
		{Type: "content", Data: " from"},
		{Type: "content", Data: " Cross-Origin"},
		{Type: "content", Data: " Server!"},
		{Type: "task_completed", Data: ""},
	}

	for _, msg := range steps {
		data, _ := json.Marshal(msg)
		fmt.Fprintf(w, "data: %s\n\n", data)
		flusher.Flush()
		time.Sleep(500 * time.Millisecond)
	}
}

// --- Main ---

func main() {
	// 启动 Mock 上游服务
	go func() {
		http.HandleFunc("/mock-sse", MockUpstreamSSE)
		log.Println("Mock Upstream Server starting on :8081")
		if err := http.ListenAndServe(":8081", nil); err != nil {
			log.Fatal("Mock Upstream Server failed:", err)
		}
	}()

	// 启动主代理服务 (应用 CORS 中间件)
	// 注意:我们将 HandleChatStream 包裹在 corsMiddleware 中
	http.HandleFunc("/chat", corsMiddleware(HandleChatStream))

	log.Println("Proxy Server with CORS starting on :8080")
	log.Println("Test URL: http://localhost:8080/chat")
	log.Println("Allowed Origins:", allowedOrigins)

	if err := http.ListenAndServe(":8080", nil); err != nil {
		log.Fatal("Proxy Server failed:", err)
	}
}

前端代码

html
<!DOCTYPE html>
<head>
    <meta charset="UTF-8">
    <title>Title</title>
</head>
<body>
<h1>SSE Test</h1>
<div id="output"></div>
<script>
    const evtSource = new EventSource("http://localhost:8080/chat");
    const output = document.getElementById("output");

    // 监听消息
    evtSource.onmessage = function(event) {
        try {
            const data = JSON.parse(event.data);
            const p = document.createElement("p");
            p.textContent = `[${data.type}] (ID:${data.id}) ${data.data}`;
            output.appendChild(p);

            // 如果收到完成信号,可以在这里选择立即关闭,也可以等 onerror 再关
            if (data.type === 'task_completed') {
                console.log("✅ 收到完成信号,等待连接关闭...");
                // 注意:这里不要立即 close(),因为可能还有最后一个包在传输中
                // 我们依靠下面的 onerror 来统一处理关闭逻辑
            }
        } catch (e) {
            console.error("JSON 解析失败:", e);
        }
    };

    // 监听错误/断开
    evtSource.onerror = function(err) {
        // 关键步骤:无论什么情况,一旦触发 error,立刻手动关闭!
        // 这会阻止浏览器的自动重连机制,打破死循环
        evtSource.close();

        const state = evtSource.readyState;
        // readyState: 0=CONNECTING, 1=OPEN, 2=CLOSED
        // 调用 close() 后,状态应该变为 2

        if (state === 2) {
            console.log("✅ 连接已正常关闭 (ReadyState: CLOSED),不再重连。");
        } else {
            // 理论上调用 close() 后应该是 2,如果是其他状态说明还在挣扎
            console.warn("⚠️ 连接断开,已强制停止重连。");
        }
    };
</script>
</body>

问题

下面的代码,你可能会有一些疑问。

  1. 前端,server客户端是如何知道sse接口输出完毕了,没看到相关结束的代码。 答:SSE 协议本身确实没有一个专门的“结束包”(比如 event: end),它的结束机制完全依赖于 HTTP 协议的底层行为。

1. 客户端(浏览器)如何知道结束了?

答案:靠“连接断开” (TCP FIN)。

浏览器并不“知道”数据是否发完了,它只知道服务器把电话挂了

  • 过程
    1. 后端 Go 代码执行完 for 循环,函数返回。
    2. Go 的 net/http 库检测到 Handler 函数结束,会自动向客户端发送 TCP FIN 包(关闭连接)。
    3. 浏览器的 EventSource 对象监听到底层 TCP 连接关闭。
    4. 判定逻辑
      • 如果连接意外断开(如网线拔了),触发 onerror 并尝试重连。
      • 如果连接是正常关闭(收到 FIN),依然触发 onerror
    5. 关键点:浏览器无法区分“网络故障”和“任务完成”。这就是为什么必须在前端代码里手动调用 evtSource.close() 来阻止重连的原因。

比喻:就像打电话。对方说完了最后一句话,直接挂断电话(嘟嘟嘟...)。你听到忙音(连接断开),就知道对方不会再说话了。但你不知道他是“说完了”还是“手机没电了”,除非你们之前约定好“说完‘再见’就挂电话”。


2. 后端(Go 代理)如何知道上游不再输出数据了?

答案:靠 io.EOF (End Of File)。

在你的代码中,这个逻辑隐藏在 bufio.Reader 的读取循环里。请看这段核心代码:

go
for {
    // ...
    
    // 1. 尝试读取一行
    line, err := reader.ReadBytes('\n')
    
    // 2. 检查错误
    if len(line) == 0 && err != nil {
        if err == io.EOF { 
            // 👉 关键在这里!
            // 当上游服务器关闭连接时,ReadBytes 会返回 io.EOF
            break // 跳出循环,结束处理
        }
        return fmt.Errorf("read error: %w", err)
    }
    
    // ... 处理数据 ...
}

详细流程解析

  1. 上游 Mock 服务 (MockUpstreamSSE)

    • 它有一个 for _, msg := range steps 循环。
    • 当循环跑完(发完 "Server!" 和 "task_completed"),函数执行结束。
    • Go 的 http.ResponseWriter 在函数结束时,自动关闭与代理服务的连接。
  2. 代理服务 (StreamProxy)

    • 它正在 reader.ReadBytes('\n') 这里阻塞等待数据。
    • 突然,上游关闭了连接。
    • ReadBytes 立刻返回。此时 err 的值是 io.EOF(表示流已尽)。
    • 代码检测到 err == io.EOF,执行 break
    • StreamProxy 函数的 for 循环结束,函数返回。
    • 代理服务随之关闭与前端的连接。

3. select { case <-ctx.Done(): ... } 什么意思? 会一直卡到这么

select和 <- 是配套使用的,因为 <- 是从通道里面拿东西,是一个阻塞的行为,只能用select,你不能用if switch。

也就是说每次循环,都会执行一次select,判断本次会话是不是关闭了。重点来了,必须要有default才行,才会走下面的逻辑代码,不然会一直阻塞在 <-

go
for {
    // 【关键点】每次循环开始前,先检查一下“客户还在吗?”
    select {
    case <-ctx.Done():
        // 如果客户断了,这里立刻执行
        log.Println("客户走了,我不读了!")
        return ctx.Err() // 直接退出函数,释放资源
    default:
        // 如果客户还在,default 分支立刻执行,不会阻塞,继续往下读数据
    }

    // 读取上游数据...
    line, err := reader.ReadBytes('\n')
    // ...处理数据...
}

如有转载或 CV 的请标注本站原文地址