package provider

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/sashabaranov/go-openai"
)

// ==============================
// Stream reader core implementation
// ==============================

// StreamEvent represents a single streaming event.
type StreamEvent struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		Index        int    `json:"index"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}

// StreamReader consumes a streaming (SSE) API response.
type StreamReader struct {
	scanner       *bufio.Scanner
	response      *http.Response
	events        chan StreamEvent
	errors        chan error
	done          chan struct{}
	closeOnce     sync.Once
	mu            sync.Mutex
	totalTokens   int
	startTime     time.Time
	lastEventTime time.Time
}

// NewStreamReader creates a stream reader and starts consuming the response body.
func NewStreamReader(response *http.Response) *StreamReader {
	sr := &StreamReader{
		response:      response,
		scanner:       bufio.NewScanner(response.Body),
		events:        make(chan StreamEvent, 100),
		errors:        make(chan error, 1),
		done:          make(chan struct{}),
		startTime:     time.Now(),
		lastEventTime: time.Now(),
	}
	// Allow events larger than bufio.Scanner's 64 KiB default token size.
	sr.scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	sr.scanner.Split(sr.splitSSE) // use the custom SSE split function
	go sr.processStream()
	return sr
}

// splitSSE is a custom bufio.SplitFunc that tokenizes Server-Sent Events.
func (sr *StreamReader) splitSSE(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	// Look for the event delimiter: \n\n or \r\n\r\n.
	if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
		return i + 2, data[0:i], nil
	}
	if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
		return i + 4, data[0:i], nil
	}
	// At EOF, return whatever remains.
	if atEOF {
		return len(data), data, nil
	}
	// Request more data.
	return 0, nil, nil
}

// processStream reads events off the response body and forwards them on the events channel.
func (sr *StreamReader) processStream() {
	defer close(sr.events)
	defer close(sr.errors)
	defer sr.response.Body.Close()

	for sr.scanner.Scan() {
		eventData := sr.scanner.Bytes()

		sr.mu.Lock()
		sr.lastEventTime = time.Now()
		sr.mu.Unlock()

		// Skip empty lines and SSE comments.
		if len(eventData) == 0 || bytes.HasPrefix(eventData, []byte(":")) {
			continue
		}

		// The [DONE] sentinel marks the end of the stream.
		if bytes.Contains(eventData, []byte("[DONE]")) {
			return
		}

		// Strip the "data: " field prefix.
		prefix := []byte("data: ")
		if bytes.HasPrefix(eventData, prefix) {
			eventData = bytes.TrimPrefix(eventData, prefix)
		}

		var event StreamEvent
		if err := json.Unmarshal(eventData, &event); err != nil {
			sr.errors <- fmt.Errorf("failed to parse event: %w\nraw data: %s", err, string(eventData))
			return
		}

		// Update the token count (a rough approximation based on whitespace-separated words).
		if len(event.Choices) > 0 {
			if content := event.Choices[0].Delta.Content; content != "" {
				sr.mu.Lock()
				sr.totalTokens += len(strings.Fields(content))
				sr.mu.Unlock()
			}
		}

		// Forward the event, bailing out if the reader has been closed.
		select {
		case sr.events <- event:
		case <-sr.done:
			return
		}
	}

	if err := sr.scanner.Err(); err != nil {
		sr.errors <- fmt.Errorf("scan error: %w", err)
	}
}

// Recv returns the next event, or io.EOF once the stream has ended.
func (sr *StreamReader) Recv() (StreamEvent, error) {
	select {
	case event, ok := <-sr.events:
		if !ok {
			return StreamEvent{}, io.EOF
		}
		return event, nil
	case err, ok := <-sr.errors:
		if !ok {
			return StreamEvent{}, io.EOF
		}
		return StreamEvent{}, err
	case <-sr.done:
		return StreamEvent{}, io.EOF
	}
}

// Close stops the reader and releases the underlying response body.
func (sr *StreamReader) Close() error {
	sr.closeOnce.Do(func() {
		close(sr.done)
		sr.response.Body.Close()
	})
	return nil
}

// Stats returns a human-readable summary of the stream so far.
func (sr *StreamReader) Stats() string {
	sr.mu.Lock()
	defer sr.mu.Unlock()
	duration := time.Since(sr.startTime)
	return fmt.Sprintf(
		"tokens: %d | duration: %s | last event: %s ago",
		sr.totalTokens,
		duration.Round(time.Second),
		time.Since(sr.lastEventTime).Round(time.Second),
	)
}
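// ==============================
// Illustrative sketch (not part of the original client): a quick demonstration
// of how splitSSE tokenizes a raw SSE chunk. The payload below is made up for
// this example; in practice a snippet like this would normally live in a
// _test.go file.
// ==============================

func demoSplitSSE() {
	raw := "data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0}]}\n\n" +
		"data: [DONE]\n\n"

	scanner := bufio.NewScanner(strings.NewReader(raw))
	// splitSSE does not touch its receiver, so a zero StreamReader is sufficient here.
	scanner.Split((&StreamReader{}).splitSSE)

	for scanner.Scan() {
		// Each token is one complete SSE event with the trailing blank line stripped.
		fmt.Printf("event: %q\n", scanner.Text())
	}
	// Prints:
	// event: "data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0}]}"
	// event: "data: [DONE]"
}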
// ==============================
// DeepSeek streaming API client
// ==============================

type DeepSeekStreamClient struct {
	apiKey   string
	endpoint string
}

func NewDeepSeekStreamClient(apiKey string) *DeepSeekStreamClient {
	return &DeepSeekStreamClient{
		apiKey:   apiKey,
		endpoint: "https://api.deepseek.com/v1/chat/completions",
	}
}

func (c *DeepSeekStreamClient) CreateStream(
	ctx context.Context,
	messages []openai.ChatCompletionMessage,
	model string,
	maxTokens int,
) (*StreamReader, error) {
	// Build the request body.
	requestBody := struct {
		Model     string                         `json:"model"`
		Messages  []openai.ChatCompletionMessage `json:"messages"`
		MaxTokens int                            `json:"max_tokens"`
		Stream    bool                           `json:"stream"`
	}{
		Model:     model,
		Messages:  messages,
		MaxTokens: maxTokens,
		Stream:    true,
	}

	body, err := json.Marshal(requestBody)
	if err != nil {
		return nil, fmt.Errorf("failed to encode request body: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Cache-Control", "no-cache")
	req.Header.Set("Connection", "keep-alive")

	// Send the request.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}

	// Check the response status.
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		respBody, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API error: %s, response: %s", resp.Status, string(respBody))
	}

	return NewStreamReader(resp), nil
}

// ==============================
// Usage example
// ==============================

// func main() {
// 	apiKey := "your_deepseek_api_key" // replace with your API key
// 	client := NewDeepSeekStreamClient(apiKey)
//
// 	// Build the conversation messages.
// 	messages := []openai.ChatCompletionMessage{
// 		{
// 			Role:    openai.ChatMessageRoleSystem,
// 			Content: "You are a helpful AI assistant",
// 		},
// 		{
// 			Role:    openai.ChatMessageRoleUser,
// 			Content: "Please explain in detail the basic principles of quantum computing and why it has advantages over classical computers.",
// 		},
// 	}
//
// 	// Create the streaming request.
// 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
// 	defer cancel()
//
// 	stream, err := client.CreateStream(ctx, messages, "deepseek-reasoner", 1000)
// 	if err != nil {
// 		fmt.Println("failed to create stream:", err)
// 		return
// 	}
// 	defer stream.Close()
//
// 	fmt.Println("streaming response started:")
// 	fmt.Println("====================")
//
// 	var fullResponse strings.Builder
// 	lastPrint := time.Now()
// 	const updateInterval = 100 * time.Millisecond
//
// 	for {
// 		event, err := stream.Recv()
// 		if errors.Is(err, io.EOF) { // requires importing "errors"
// 			fmt.Println("\n\n====================")
// 			fmt.Println("streaming response ended")
// 			break
// 		}
// 		if err != nil {
// 			fmt.Println("\nreceive error:", err)
// 			break
// 		}
//
// 		if len(event.Choices) > 0 {
// 			content := event.Choices[0].Delta.Content
// 			if content != "" {
// 				fullResponse.WriteString(content)
// 				// Throttle console output: print at most once per updateInterval
// 				// (fragments arriving in between are only shown in the final full response).
// 				if time.Since(lastPrint) > updateInterval {
// 					fmt.Print(content)
// 					lastPrint = time.Now()
// 				}
// 			}
// 		}
// 	}
//
// 	fmt.Println("\nfull response:")
// 	fmt.Println(fullResponse.String())
// 	fmt.Println("\nstats:", stream.Stats())
// }
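// ==============================
// Illustrative sketch (not part of the original client): a convenience helper
// that drains a StreamReader into a single string, showing the Recv/io.EOF
// contract in compact form. The onDelta callback is a made-up parameter for
// this example; pass nil if per-fragment output is not needed. Usage roughly:
//
//	stream, err := client.CreateStream(ctx, messages, "deepseek-reasoner", 1000)
//	if err != nil { /* handle error */ }
//	text, err := collectStream(stream, func(delta string) { fmt.Print(delta) })
//
// ==============================

func collectStream(stream *StreamReader, onDelta func(string)) (string, error) {
	defer stream.Close()

	var full strings.Builder
	for {
		event, err := stream.Recv()
		if err == io.EOF {
			// io.EOF marks the normal end of the stream ([DONE] or the body closing).
			return full.String(), nil
		}
		if err != nil {
			return full.String(), err
		}
		if len(event.Choices) > 0 {
			if content := event.Choices[0].Delta.Content; content != "" {
				full.WriteString(content)
				if onDelta != nil {
					onDelta(content)
				}
			}
		}
	}
}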