goaiagent/provider/openai.go

60 lines
1.4 KiB
Go

package provider
import (
	"context"
	"errors"
	"strings"

	"git.kingecg.top/kingecg/goaiagent/model"
	"github.com/sashabaranov/go-openai"
)
// ==============================
// OpenAI-compatible implementation
// ==============================

// OpenAiProvider issues chat-completion requests through a go-openai client
// using a single model name fixed at construction time.
type OpenAiProvider struct {
// client is the go-openai SDK client used to send requests.
client *openai.Client
// model is the model name sent with every request and returned by GetModelName.
model string
}
// NewOpenAiProvider returns an OpenAiProvider whose client is built from the
// go-openai default configuration for apiKey and which uses model for every
// request it sends.
func NewOpenAiProvider(apiKey, model string) *OpenAiProvider {
	cfg := openai.DefaultConfig(apiKey)
	cli := openai.NewClientWithConfig(cfg)
	return &OpenAiProvider{client: cli, model: model}
}
// CreateChatCompletion converts req into a go-openai chat-completion request,
// sends it through the provider's client, and maps the first returned choice
// onto a model.ChatResponse together with the prompt/completion token usage
// reported by the API.
//
// The model name configured on the provider is always used; from req only the
// messages (role and content) and MaxTokens are forwarded.
//
// It returns the client's error unchanged when the call fails, and a non-nil
// error when the response contains no choices, so callers never hit an
// index-out-of-range panic on an empty Choices slice.
func (o *OpenAiProvider) CreateChatCompletion(ctx context.Context, req model.ChatRequest) (model.ChatResponse, error) {
	// Convert the unified request messages into the OpenAI SDK format.
	messages := make([]openai.ChatCompletionMessage, len(req.Messages))
	for i, m := range req.Messages {
		messages[i] = openai.ChatCompletionMessage{
			Role:    m.Role,
			Content: m.Content,
		}
	}

	resp, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
		Model:     o.model,
		Messages:  messages,
		MaxTokens: req.MaxTokens,
	})
	if err != nil {
		return model.ChatResponse{}, err
	}

	// A nil error does not guarantee at least one choice; guard the index
	// below instead of panicking.
	if len(resp.Choices) == 0 {
		return model.ChatResponse{}, errors.New("openai: chat completion returned no choices")
	}

	return model.ChatResponse{
		Content:      resp.Choices[0].Message.Content,
		InputTokens:  resp.Usage.PromptTokens,
		OutputTokens: resp.Usage.CompletionTokens,
	}, nil
}
// CountTokens returns the number of whitespace-separated fields in text.
//
// This is a word count used as a stand-in for a token count; no tokenizer is
// involved, so the result is only an approximation of what the API would bill.
func (o *OpenAiProvider) CountTokens(text string) int {
	words := strings.Fields(text)
	return len(words)
}
// GetModelName reports the model name this provider was constructed with.
func (o *OpenAiProvider) GetModelName() string {
	name := o.model
	return name
}