first commit
This commit is contained in:
214
service/search_handler.go
Normal file
214
service/search_handler.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"ai-search/service/logger"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func SearchHandler(c *gin.Context) {
|
||||
searchReq := &SearchReq{}
|
||||
if err := c.Copy().ShouldBindJSON(searchReq); err != nil {
|
||||
ErrResp[gin.H](c, nil, "error", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cachedResult, err := getCacheResp(c, searchReq.SearchUUID)
|
||||
if err == nil {
|
||||
logger.Logger(c).Infof("cache key hit [%s], query: [%s]", searchReq.SearchUUID, searchReq.Query)
|
||||
c.String(http.StatusOK, cachedResult)
|
||||
return
|
||||
}
|
||||
|
||||
if searchReq.Query == "" && searchReq.SearchUUID == "" {
|
||||
ErrResp[gin.H](c, nil, "param is invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if searchReq.Query == "" && searchReq.SearchUUID != "" {
|
||||
ErrResp[gin.H](c, nil, "content is gone", http.StatusGone)
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
cli := NewSearchClient()
|
||||
searchResp, err := cli.Search(c, searchReq.Query, GetSettings().RAGSearchCount)
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Errorf("client.Search error")
|
||||
return
|
||||
}
|
||||
ss := &Sources{}
|
||||
ss.FromSearchResp(&searchResp, searchReq.Query, searchReq.SearchUUID)
|
||||
|
||||
originReq := &openai.ChatCompletionRequest{
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: fmt.Sprintf(RagPrompt(), getSearchContext(ss)),
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: searchReq.Query,
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
apiKey, endpoint := getOpenAIConfig(c, "chat")
|
||||
conf := openai.DefaultConfig(apiKey)
|
||||
conf.BaseURL = endpoint
|
||||
client := openai.NewClientWithConfig(conf)
|
||||
request := openai.ChatCompletionRequest{
|
||||
Model: openai.GPT3Dot5Turbo,
|
||||
Messages: originReq.Messages,
|
||||
Temperature: GetSettings().RAGParams.Temperature,
|
||||
MaxTokens: GetSettings().RAGParams.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletionStream(
|
||||
context.Background(),
|
||||
request,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletionStream error")
|
||||
}
|
||||
|
||||
relatedStrChan := make(chan string)
|
||||
defer close(relatedStrChan)
|
||||
go func() {
|
||||
relatedStrChan <- getRelatedQuestionsResp(c, searchReq.Query, ss)
|
||||
}()
|
||||
|
||||
finalResult := streamSearchItemResp(c, []string{
|
||||
ss.ToString(),
|
||||
"\n\n__LLM_RESPONSE__\n\n",
|
||||
})
|
||||
finalResult = finalResult + streamResp(c, resp)
|
||||
finalResult = finalResult + streamSearchItemResp(c, []string{
|
||||
"\n\n__RELATED_QUESTIONS__\n\n",
|
||||
// `[{"question": "What is the formal way to say hello in Chinese?"}, {"question": "How do you say 'How are you' in Chinese?"}]`,
|
||||
<-relatedStrChan,
|
||||
})
|
||||
|
||||
GetSearchCache().Set([]byte(searchReq.SearchUUID), newCachedResult(searchReq.SearchUUID, searchReq.Query, finalResult).ToBytes(), GetSettings().RAGSearchCacheTime)
|
||||
logger.Logger(c).Infof("cache key miss [%s], query: [%s], set result to cache", searchReq.SearchUUID, searchReq.Query)
|
||||
}
|
||||
|
||||
func getCacheResp(c *gin.Context, searchUUID string) (string, error) {
|
||||
ans, err := GetSearchCache().Get([]byte(searchUUID))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if ans != nil {
|
||||
cachedResult := &cachedResult{}
|
||||
cachedResult.FromBytes(ans)
|
||||
return cachedResult.Result, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("cache not found")
|
||||
}
|
||||
|
||||
func streamSearchItemResp(c *gin.Context, t []string) string {
|
||||
result := ""
|
||||
|
||||
_, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
logger.Logger(c).Panic("server not support")
|
||||
}
|
||||
defer func() {
|
||||
c.Writer.Flush()
|
||||
}()
|
||||
|
||||
for _, line := range t {
|
||||
_, err := c.Writer.WriteString(line)
|
||||
result += line
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Error("write string to client error")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
logger.Logger(c).Info("finish stream text to client")
|
||||
return result
|
||||
}
|
||||
|
||||
func getRelatedQuestionsResp(c context.Context, query string, ss *Sources) string {
|
||||
apiKey, endpoint := getOpenAIConfig(c, "chat")
|
||||
conf := openai.DefaultConfig(apiKey)
|
||||
conf.BaseURL = endpoint
|
||||
client := openai.NewClientWithConfig(conf)
|
||||
request := openai.ChatCompletionRequest{
|
||||
Model: openai.GPT3Dot5Turbo,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: fmt.Sprintf(MoreQuestionsPrompt(), getSearchContext(ss)) + query,
|
||||
},
|
||||
},
|
||||
Temperature: GetSettings().RAGParams.MoreQuestionsTemperature,
|
||||
MaxTokens: GetSettings().RAGParams.MoreQuestionsMaxTokens,
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
request,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletion error")
|
||||
}
|
||||
|
||||
mode := 1
|
||||
|
||||
cs := strings.Split(resp.Choices[0].Message.Content, ". ")
|
||||
if len(cs) == 1 {
|
||||
cs = strings.Split(resp.Choices[0].Message.Content, "- ")
|
||||
mode = 2
|
||||
}
|
||||
rq := []string{}
|
||||
for i, line := range cs {
|
||||
if len(line) <= 2 {
|
||||
continue
|
||||
}
|
||||
if i != len(cs)-1 && mode == 1 {
|
||||
line = line[:len(line)-1]
|
||||
}
|
||||
rq = append(rq, line)
|
||||
}
|
||||
|
||||
return parseRelatedQuestionsResp(rq)
|
||||
}
|
||||
|
||||
// parseRelatedQuestionsResp encodes the given questions as a JSON array of
// objects of the form {"question": "..."}. Entries whose space-trimmed text
// is two characters or fewer are dropped as noise; kept entries preserve
// their original (untrimmed) text.
func parseRelatedQuestionsResp(qs []string) string {
	type question struct {
		Question string `json:"question"`
	}

	out := []question{}
	for _, candidate := range qs {
		if len(strings.Trim(candidate, " ")) > 2 {
			out = append(out, question{Question: candidate})
		}
	}

	// A slice of string-only structs cannot fail to marshal.
	encoded, _ := json.Marshal(out)
	return string(encoded)
}
|
||||
|
||||
func getSearchContext(ss *Sources) string {
|
||||
ans := ""
|
||||
for i, ctx := range ss.Contexts {
|
||||
ans = ans + fmt.Sprintf("[[citation:%d]] ", i+1) + ctx.Snippet + "\n\n"
|
||||
}
|
||||
return ans
|
||||
}
|
||||
Reference in New Issue
Block a user