first commit
This commit is contained in:
129
service/helper.go
Normal file
129
service/helper.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"ai-search/service/logger"
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func getOpenAIConfig(c context.Context, mode string) (token, endpoint string) {
|
||||
token = GetSettings().OpenAIAPIKey
|
||||
endpoint = GetSettings().OpenAIEndpint
|
||||
if mode == "chat" {
|
||||
token = GetSettings().OpenAIChatAPIKey
|
||||
endpoint = GetSettings().OpenAIChatEndpoint
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type Queue struct {
|
||||
data []time.Duration
|
||||
length int
|
||||
}
|
||||
|
||||
func NewQueue(length int, defaultValue int) *Queue {
|
||||
data := make([]time.Duration, 0, length)
|
||||
for i := 0; i < length; i++ {
|
||||
data = append(data, time.Duration(defaultValue)*time.Millisecond)
|
||||
}
|
||||
return &Queue{
|
||||
data: data,
|
||||
length: length,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Queue) Add(value time.Duration) {
|
||||
if len(q.data) >= q.length {
|
||||
q.data = q.data[1:]
|
||||
}
|
||||
q.data = append(q.data, value)
|
||||
}
|
||||
|
||||
func (q *Queue) Avg(k ...int) time.Duration {
|
||||
param := 1
|
||||
if len(k) > 0 {
|
||||
param = k[0]
|
||||
}
|
||||
total := time.Duration(0)
|
||||
count := 0
|
||||
for _, value := range q.data {
|
||||
if value != 0 {
|
||||
total += value
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count == 0 {
|
||||
return time.Duration(0)
|
||||
}
|
||||
ans := total * time.Duration(param) / time.Duration(count)
|
||||
if ans > time.Duration(20)*time.Millisecond {
|
||||
return time.Duration(20) * time.Millisecond
|
||||
}
|
||||
return ans
|
||||
}
|
||||
|
||||
func streamResp(c *gin.Context, resp *openai.ChatCompletionStream) string {
|
||||
result := ""
|
||||
|
||||
if resp == nil {
|
||||
logger.Logger(c).Error("stream resp is nil")
|
||||
return result
|
||||
}
|
||||
|
||||
_, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
logger.Logger(c).Panic("server not support")
|
||||
}
|
||||
defer func() {
|
||||
c.Writer.Flush()
|
||||
}()
|
||||
|
||||
queue := NewQueue(GetSettings().OpenAIChatQueueLen, GetSettings().OpenAIChatNetworkDelay)
|
||||
ch := make(chan rune, 1024)
|
||||
|
||||
go func(c *gin.Context, msgChan chan rune) {
|
||||
lastTime := time.Now()
|
||||
for {
|
||||
line, err := resp.Recv()
|
||||
if err != nil {
|
||||
close(msgChan)
|
||||
logger.Logger(c).WithError(err).Error("read openai completion line error")
|
||||
return
|
||||
}
|
||||
if len(line.Choices[0].Delta.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
nowTime := time.Now()
|
||||
|
||||
division := strings.Count(line.Choices[0].Delta.Content, "")
|
||||
for _, v := range line.Choices[0].Delta.Content {
|
||||
msgChan <- v
|
||||
}
|
||||
|
||||
during := (nowTime.Sub(lastTime) + (time.Duration(GetSettings().OpenAIChatNetworkDelay) *
|
||||
time.Millisecond)) / time.Duration(division)
|
||||
queue.Add(during)
|
||||
lastTime = nowTime
|
||||
}
|
||||
}(c, ch)
|
||||
|
||||
for char := range ch {
|
||||
str := string(char)
|
||||
_, err := c.Writer.WriteString(str)
|
||||
result += str
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Error("write string to client error")
|
||||
return ""
|
||||
}
|
||||
c.Writer.Flush()
|
||||
time.Sleep(queue.Avg(len(str) * 2)) // 英文平均长度为6个字符,一个UTF8字符是3个长度,试图让一个单词等于一个汉字
|
||||
}
|
||||
logger.Logger(c).Info("finish stream text to client")
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user