Feat/rag by vaayne · Pull Request #9 · recally-io/recally · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Feat/rag #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ require (
github.com/pgvector/pgvector-go v0.2.0
github.com/riverqueue/river v0.9.0
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.9.0
github.com/sashabaranov/go-openai v1.27.0
github.com/sashabaranov/go-openai v1.30.3
github.com/stretchr/testify v1.9.0
github.com/swaggo/echo-swagger v1.4.1
github.com/swaggo/swag v1.16.3
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,8 @@ github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/sagikazarmark/crypt v0.6.0/go.mod h1:U8+INwJo3nBv1m6A/8OBXAq7Jnpspk5AxSgDyEQcea8=
github.com/sashabaranov/go-openai v1.27.0 h1:L3hO6650YUbKrbGUC6yCjsUluhKZ9h1/jcgbTItI8Mo=
github.com/sashabaranov/go-openai v1.27.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.30.3 h1:TEdRP3otRXX2A7vLoU+kI5XpoSo7VUUlM/rEttUqgek=
github.com/sashabaranov/go-openai v1.30.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/shirou/gopsutil/v3 v3.23.12 h1:z90NtUkp3bMtmICZKpC4+WaknU1eXtp5vtbQ11DgpE4=
github.com/shirou/gopsutil/v3 v3.23.12/go.mod h1:1FrWgea594Jp7qmjHUUPlJDTPgcsb9mGnXDxavtikzM=
Expand Down
10 changes: 9 additions & 1 deletion internal/core/assistants/assistant_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)

// RagSettings groups the boolean switches that control retrieval-augmented
// generation (RAG). Every flag defaults to false (the Go zero value).
type RagSettings struct {
// Enable turns RAG on; when chat messages are built for a thread, the last
// user message is only rewritten with retrieved context if this is true.
Enable bool `json:"enable"`
// MultiQuery: NOTE(review): no consumer is visible here; presumably enables
// issuing several retrieval queries per user message — confirm.
MultiQuery bool `json:"multi_query"`
// QueryRewrite: NOTE(review): no consumer is visible here; presumably
// enables rewriting the query before retrieval — confirm.
QueryRewrite bool `json:"query_rewrite"`
// Rerank: NOTE(review): no consumer is visible here; presumably enables
// re-ranking of retrieved documents — confirm.
Rerank bool `json:"rerank"`
}

// AssistantMetadata is the optional configuration attached to an assistant;
// its fields carry JSON tags for serialization.
// NOTE(review): this diff view shows both the pre-change `Tools` line and its
// replacement (the one with omitempty); only the latter exists after the change.
type AssistantMetadata struct {
// Tools is a list of tools that the assistant can use
Tools []string `json:"tools"`
Tools []string `json:"tools,omitempty"`
// RagSettings holds the assistant-level RAG switches.
// NOTE(review): encoding/json's omitempty never omits a struct value, so
// "rag_settings" is always emitted — confirm that this is intended.
RagSettings RagSettings `json:"rag_settings,omitempty"`
}

type AssistantDTO struct {
Expand Down
4 changes: 4 additions & 0 deletions internal/core/assistants/assistant_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ func (s *Service) GetAssistant(ctx context.Context, tx db.DBTX, id uuid.UUID) (*
}

func (s *Service) DeleteAssistant(ctx context.Context, tx db.DBTX, assistantId uuid.UUID) error {
// Delete associated attachments
if err := s.dao.DeleteAssistantAttachmentsByAssistantId(ctx, tx, pgtype.UUID{Bytes: assistantId, Valid: true}); err != nil {
return fmt.Errorf("failed to delete assistant attachments: %w", err)
}
// Delete associated threads and messages
if err := s.dao.DeleteThreadMessagesByAssistant(ctx, tx, pgtype.UUID{Bytes: assistantId, Valid: true}); err != nil {
return fmt.Errorf("failed to delete thread messages by assistant: %w", err)
Expand Down
28 changes: 28 additions & 0 deletions internal/core/assistants/attachment_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@ func (s *Service) CreateAttachment(ctx context.Context, tx db.DBTX, attachment *
}

attachment.Load(&ast)

if attachment.ThreadId != uuid.Nil {
// update thread to enable rag
t, err := s.GetThread(ctx, tx, attachment.ThreadId)
if err != nil {
logger.FromContext(ctx).Error("failed to get thread", "err", err)
return attachment, nil
}
if !t.Metadata.RagSettings.Enable {
t.Metadata.RagSettings.Enable = true
if _, err := s.UpdateThread(ctx, tx, t); err != nil {
logger.FromContext(ctx).Error("failed to update thread", "err", err)
}
}
} else {
// update assistant to enable rag
a, err := s.GetAssistant(ctx, tx, attachment.AssistantId)
if err != nil {
logger.FromContext(ctx).Error("failed to get assistant", "err", err)
return attachment, nil
}
if !a.Metadata.RagSettings.Enable {
a.Metadata.RagSettings.Enable = true
if _, err := s.UpdateAssistant(ctx, tx, a); err != nil {
logger.FromContext(ctx).Error("failed to update assistant", "err", err)
}
}
}
return attachment, nil
}

Expand Down
8 changes: 5 additions & 3 deletions internal/core/assistants/message_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"time"
"vibrain/internal/pkg/db"
"vibrain/internal/pkg/llms"
"vibrain/internal/pkg/logger"

"github.com/google/uuid"
Expand All @@ -12,9 +13,10 @@ import (
)

// MessageMetadata is the per-message metadata; its fields carry JSON tags for
// serialization.
// NOTE(review): this diff view lists the three pre-change fields first,
// followed by the four post-change fields; only the latter set exists after
// the change.
type MessageMetadata struct {
Tools []string `json:"tools"`
Images []string `json:"images"`
Stream bool `json:"stream"`
// Tools holds tool names; RunThread forwards them via llms.WithToolNames.
Tools []string `json:"tools"`
// Images: NOTE(review): not used in the visible code; presumably image
// references attached to the message — confirm.
Images []string `json:"images"`
// Stream is forwarded via llms.WithStream when the thread is run.
Stream bool `json:"stream"`
// IntermediateSteps records the steps collected while a reply is produced
// (RunThread stores the RAG and streaming steps here).
IntermediateSteps []llms.IntermediateStep `json:"intermediate_steps"`
}

// 1563 dimensions
Expand Down
23 changes: 21 additions & 2 deletions internal/core/assistants/thread_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,27 @@ import (
)

// ThreadMetadata is the metadata stored with a thread.
// NOTE(review): this diff view lists the two pre-change fields first, followed
// by the two post-change fields; only the latter pair exists after the change.
type ThreadMetadata struct {
IsGeneratedTitle bool `json:"is_generated_title"`
Tools []string `json:"tools"`
// AssistantMetadata is embedded, so its Tools and RagSettings fields are
// promoted onto the thread and encoding/json flattens them into the same
// JSON object.
AssistantMetadata
// IsGeneratedTitle: NOTE(review): presumably marks that the thread title
// was generated automatically — confirm against GenerateThreadTitle.
IsGeneratedTitle bool `json:"is_generated_title"`
}

// Merge merges the assistant metadata into the thread metadata in place.
//
// RagSettings: every flag starts from the assistant's value and is overridden
// only when the thread carries a non-default (true) value, so a flag ends up
// enabled when either the assistant or the thread enabled it. Copying the
// thread's flags unconditionally would always discard the assistant's
// settings, because a thread that never configured RAG holds all-false values.
//
// Tools: the thread's own list wins when it is non-empty; otherwise the
// assistant's list is used.
func (m *ThreadMetadata) Merge(am AssistantMetadata) {
	mergedRagSettings := am.RagSettings // Start with assistant's RagSettings

	// Override assistant's RagSettings only where the thread has non-default values.
	if m.RagSettings.Enable {
		mergedRagSettings.Enable = true
	}
	if m.RagSettings.MultiQuery {
		mergedRagSettings.MultiQuery = true
	}
	if m.RagSettings.QueryRewrite {
		mergedRagSettings.QueryRewrite = true
	}
	if m.RagSettings.Rerank {
		mergedRagSettings.Rerank = true
	}
	m.RagSettings = mergedRagSettings

	// Fall back to the assistant's tools only when the thread has none of its own.
	if len(m.Tools) == 0 {
		m.Tools = am.Tools
	}
}

type ThreadDTO struct {
Expand Down
83 changes: 62 additions & 21 deletions internal/core/assistants/thread_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ func (s *Service) GetThread(ctx context.Context, tx db.DBTX, id uuid.UUID) (*Thr
var t ThreadDTO
t.Load(&th)

ass, err := s.dao.GetAssistant(ctx, tx, th.AssistantID.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to get assistant: %w", err)
}

var a AssistantDTO
a.Load(&ass)
t.Metadata.Merge(a.Metadata)

// messages, err := s.ListThreadMessages(ctx, tx, th.Uuid)
// if err != nil {
// return nil, fmt.Errorf("failed to get thread messages: %w", err)
Expand All @@ -101,6 +110,10 @@ func (s *Service) GetThread(ctx context.Context, tx db.DBTX, id uuid.UUID) (*Thr
}

func (s *Service) DeleteThread(ctx context.Context, tx db.DBTX, id uuid.UUID) error {
// Delete associated attachments
if err := s.dao.DeleteAssistantAttachmentsByThreadId(ctx, tx, pgtype.UUID{Bytes: id, Valid: true}); err != nil {
return fmt.Errorf("failed to delete thread attachments: %w", err)
}
if err := s.dao.DeleteThreadMessagesByThread(ctx, tx, pgtype.UUID{Bytes: id, Valid: true}); err != nil {
return fmt.Errorf("failed to delete thread messages: %w", err)
}
Expand All @@ -117,22 +130,12 @@ func (s *Service) RunThread(ctx context.Context, tx db.DBTX, id uuid.UUID, strea
return
}

oaiMessages, model, metadata, err := s.buildChatMessages(ctx, tx, thread)
if err != nil {
streamingFunc(nil, fmt.Errorf("failed to build chat messages: %w", err))
return
}

opts := []llms.Option{
llms.WithModel(model),
llms.WithToolNames(metadata.Tools),
llms.WithStream(metadata.Stream),
}

var newMessage *MessageDTO
newMessageID := uuid.New()
sb := strings.Builder{}
var usage *openai.Usage
model := thread.Model
intermediateSteps := make([]llms.IntermediateStep, 0)

sendToUser := func(streamMsg llms.StreamingMessage) {
choice := streamMsg.Choice
Expand All @@ -141,13 +144,20 @@ func (s *Service) RunThread(ctx context.Context, tx db.DBTX, id uuid.UUID, strea
streamingFunc(nil, err)
return
}
if choice == nil {
// streamingFunc(nil, fmt.Errorf("no content generated"))
return

if len(streamMsg.IntermediateSteps) > 0 {
intermediateSteps = append(intermediateSteps, streamMsg.IntermediateSteps...)
}

if streamMsg.Usage != nil {
usage = streamMsg.Usage
}

if choice == nil {
// streamingFunc(nil, fmt.Errorf("no content generated"))
return
}

sb.WriteString(choice.Message.Content)
newMessage = &MessageDTO{
ID: newMessageID,
Expand All @@ -157,6 +167,9 @@ func (s *Service) RunThread(ctx context.Context, tx db.DBTX, id uuid.UUID, strea
Model: model,
Role: choice.Message.Role,
Text: choice.Message.Content,
Metadata: MessageMetadata{
IntermediateSteps: intermediateSteps,
},
// PromptToken: int32(usage.PromptTokens),
// CompletionToken: int32(usage.CompletionTokens),
}
Expand All @@ -168,6 +181,20 @@ func (s *Service) RunThread(ctx context.Context, tx db.DBTX, id uuid.UUID, strea
streamingFunc(newMessage, nil)
}

oaiMessages, lmodel, metadata, steps, err := s.buildChatMessages(ctx, tx, thread)
if err != nil {
streamingFunc(nil, fmt.Errorf("failed to build chat messages: %w", err))
return
}
model = lmodel
intermediateSteps = append(intermediateSteps, steps...)

opts := []llms.Option{
llms.WithModel(model),
llms.WithToolNames(metadata.Tools),
llms.WithStream(metadata.Stream),
}

s.llm.GenerateContent(ctx, oaiMessages, sendToUser, opts...)
if newMessage != nil {
newMessage.Text = sb.String()
Expand Down Expand Up @@ -225,12 +252,13 @@ func (s *Service) GenerateThreadTitle(ctx context.Context, tx db.DBTX, id uuid.U
return title, nil
}

func (s *Service) buildChatMessages(ctx context.Context, tx db.DBTX, thread *ThreadDTO) ([]openai.ChatCompletionMessage, string, MessageMetadata, error) {
func (s *Service) buildChatMessages(ctx context.Context, tx db.DBTX, thread *ThreadDTO) ([]openai.ChatCompletionMessage, string, MessageMetadata, []llms.IntermediateStep, error) {
oaiMessages := make([]openai.ChatCompletionMessage, 0)
messages, err := s.ListThreadMessages(ctx, tx, thread.Id)
metadata := messages[len(messages)-1].Metadata
steps := make([]llms.IntermediateStep, 0)
if err != nil {
return nil, "", metadata, fmt.Errorf("failed to get thread messages: %w", err)
return nil, "", metadata, steps, fmt.Errorf("failed to get thread messages: %w", err)
}
oaiMessages = append(oaiMessages, openai.ChatCompletionMessage{
Role: "system",
Expand All @@ -244,7 +272,11 @@ func (s *Service) buildChatMessages(ctx context.Context, tx db.DBTX, thread *Thr
})
}
lastMessage := messages[len(messages)-1]
s.rewriteUserMessage(ctx, tx, &lastMessage)

// Rewrite user message using RAG
if thread.Metadata.RagSettings.Enable {
steps = s.rewriteUserMessage(ctx, tx, &lastMessage)
}

// Use the model from the last message
model := thread.Model
Expand Down Expand Up @@ -279,13 +311,14 @@ func (s *Service) buildChatMessages(ctx context.Context, tx db.DBTX, thread *Thr
}
}

return oaiMessages, model, metadata, nil
return oaiMessages, model, metadata, steps, nil
}

func (s *Service) rewriteUserMessage(ctx context.Context, tx db.DBTX, message *MessageDTO) {
func (s *Service) rewriteUserMessage(ctx context.Context, tx db.DBTX, message *MessageDTO) []llms.IntermediateStep {
steps := make([]llms.IntermediateStep, 0)
if message.Text == "" {
logger.FromContext(ctx).Info("RAG for user message: message text is empty")
return
return steps
}
// 1. search for similar documents
// 2. search chat history for similar questions
Expand All @@ -306,6 +339,7 @@ func (s *Service) rewriteUserMessage(ctx context.Context, tx db.DBTX, message *M
})
if err != nil {
logger.FromContext(ctx).Error("failed to search for similar documents", "err", err)
return steps
}
for _, d := range docsRes {
var metadata map[string]any
Expand All @@ -317,6 +351,12 @@ func (s *Service) rewriteUserMessage(ctx context.Context, tx db.DBTX, message *M
Metadata: metadata,
})
}
steps = append(steps, llms.IntermediateStep{
Type: llms.IntermediateStepRag,
Name: "vector_search",
Input: map[string]string{"query": message.Text},
Output: docs,
})

// chatHistoryRes, err := s.dao.SimilaritySearchMessages(ctx, tx, db.SimilaritySearchMessagesParams{
// ThreadID: pgtype.UUID{Bytes: message.ThreadID, Valid: true},
Expand All @@ -342,4 +382,5 @@ func (s *Service) rewriteUserMessage(ctx context.Context, tx db.DBTX, message *M
}

message.Text = prompt
return steps
}
Loading
Loading
0