package provider
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"git.kingecg.top/kingecg/goaiagent/model"
	"github.com/sashabaranov/go-openai"
)
// ==============================
// DeepSeek API implementation
// ==============================

// DeepSeekProvider talks to the DeepSeek chat completion HTTP API.
// Construct it with NewDeepSeekProvider.
type DeepSeekProvider struct {
	apiKey   string // bearer token sent in the Authorization header
	model    string // "deepseek-chat" or "deepseek-reasoner"
	endpoint string // full URL of the chat completions endpoint
}
func NewDeepSeekProvider(apiKey, model string) *DeepSeekProvider {
|
||
return &DeepSeekProvider{
|
||
apiKey: apiKey,
|
||
model: model,
|
||
endpoint: "https://api.deepseek.com/v1/chat/completions", // DeepSeek官方API地址:cite[7]
|
||
}
|
||
}
|
||
func (d *DeepSeekProvider) GetProviderName() string {
|
||
return "DeepSeek"
|
||
}
|
||
|
||
func (d *DeepSeekProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
|
||
// 构造DeepSeek专属请求体:cite[1]
|
||
deepseekReq := struct {
|
||
Model string `json:"model"`
|
||
Messages []openai.ChatCompletionMessage `json:"messages"`
|
||
MaxTokens int `json:"max_tokens"`
|
||
}{
|
||
Model: d.model,
|
||
Messages: req.Messages,
|
||
MaxTokens: req.MaxTokens,
|
||
}
|
||
|
||
reqBody, _ := json.Marshal(deepseekReq)
|
||
httpReq, _ := http.NewRequest("POST", d.endpoint, bytes.NewReader(reqBody))
|
||
httpReq.Header.Set("Authorization", "Bearer "+d.apiKey)
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{}
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return model.ChatResponse{}, err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, _ := io.ReadAll(resp.Body)
|
||
|
||
// 解析DeepSeek特有响应格式
|
||
var result struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
Usage struct {
|
||
PromptTokens int `json:"prompt_tokens"`
|
||
CompletionTokens int `json:"completion_tokens"`
|
||
} `json:"usage"`
|
||
}
|
||
|
||
if err := json.Unmarshal(body, &result); err != nil {
|
||
return model.ChatResponse{}, err
|
||
}
|
||
|
||
if len(result.Choices) == 0 {
|
||
return model.ChatResponse{}, fmt.Errorf("empty response from DeepSeek")
|
||
}
|
||
|
||
return model.ChatResponse{
|
||
Content: result.Choices[0].Message.Content,
|
||
InputTokens: result.Usage.PromptTokens,
|
||
OutputTokens: result.Usage.CompletionTokens,
|
||
}, nil
|
||
}
|
||
|
||
func (d *DeepSeekProvider) CreateChatCompletionStream(ctx context.Context, req model.ChatRequest) (*StreamReader, error) {
|
||
streamClient := NewDeepSeekStreamClient(d.apiKey)
|
||
return streamClient.CreateStream(ctx, req.Messages, d.model, req.MaxTokens)
|
||
// return nil, fmt.Errorf("not implemented")
|
||
}
|
||
|
||
// 简化的Token计数器(实际生产应使用tiktoken)
|
||
func (d *DeepSeekProvider) CountTokens(text string) int {
|
||
return len(strings.Fields(text)) // 按空格分词计数
|
||
}
|
||
|
||
func (d *DeepSeekProvider) GetModelName() string {
|
||
return d.model
|
||
}
|