// internal/rag/pipeline.go
package rag
import (
	"context"
	"fmt"
	"log/slog"
	"strings"
	"time"
)
// Config holds the tunable parameters of the RAG pipeline.
type Config struct {
	VectorWeight  float64 // hybrid-search weight applied to the vector score
	KeywordWeight float64 // hybrid-search weight applied to the keyword score
	TopK          int     // number of candidates fetched by initial retrieval
	RerankTopN    int     // number of chunks kept after reranking
	ChunkSize     int     // chunk size, in tokens
	OverlapSize   int     // overlap between adjacent chunks, in tokens
	CacheEnabled  bool    // whether the query cache is consulted and populated
	CacheTTL      time.Duration // lifetime of cached query results
}

// DefaultConfig returns the recommended starting configuration.
func DefaultConfig() Config {
	var c Config
	c.VectorWeight = 0.7
	c.KeywordWeight = 0.3
	c.TopK = 20
	c.RerankTopN = 5
	c.ChunkSize = 900
	c.OverlapSize = 50
	c.CacheEnabled = true
	c.CacheTTL = 24 * time.Hour
	return c
}
// Pipeline wires together every stage of the RAG flow: query embedding,
// hybrid (vector + keyword) retrieval, reranking, answer generation, and an
// optional query cache. Each dependency is an interface so implementations
// can be swapped in tests.
type Pipeline struct {
	config    Config
	embedder  Embedder     // turns query text into a vector
	vectorDB  VectorDB     // vector search + chunk lookup
	keywordDB KeywordDB    // lexical search; failures are non-fatal (see Query)
	reranker  Reranker     // reorders candidates; failures are non-fatal
	llm       LLMGenerator // produces the final answer from the prompt
	cache     QueryCache   // consulted only when config.CacheEnabled
	logger    *slog.Logger
}
// QueryResult is the full outcome of one RAG query, including the generated
// answer, the supporting chunks, and a per-stage latency breakdown.
type QueryResult struct {
	Answer   string           `json:"answer"`
	Sources  []ChunkSource    `json:"sources"`
	Chunks   []string         `json:"chunks"` // raw text of the chunks used in the prompt
	Latency  LatencyBreakdown `json:"latency"`
	CacheHit bool             `json:"cache_hit"` // true when served from the query cache
}

// ChunkSource identifies one chunk that contributed to the answer.
type ChunkSource struct {
	ChunkID  string  `json:"chunk_id"`
	DocTitle string  `json:"doc_title"`
	Score    float64 `json:"score"` // NOTE(review): never populated by extractSources — always 0; confirm intent
}

// LatencyBreakdown records per-stage wall-clock timings in milliseconds.
type LatencyBreakdown struct {
	EmbedMs    int64 `json:"embed_ms"`
	SearchMs   int64 `json:"search_ms"` // covers vector + keyword search and RRF fusion
	RerankMs   int64 `json:"rerank_ms"`
	GenerateMs int64 `json:"generate_ms"`
	TotalMs    int64 `json:"total_ms"`
}
// Query executes the full RAG pipeline for a single user query:
// cache lookup → query embedding → hybrid search (vector + keyword, fused
// with RRF) → rerank → prompt build + LLM generation → cache write.
//
// Keyword-search and rerank failures degrade gracefully; embedding, vector
// search, chunk fetch, and generation errors abort the query.
func (p *Pipeline) Query(ctx context.Context, query string) (*QueryResult, error) {
	start := time.Now()
	latency := LatencyBreakdown{}

	// Step 0: query cache. Return a copy rather than the cached pointer:
	// flipping cached.CacheHit in place (as the previous version did) races
	// with concurrent readers of the same entry and permanently mutates the
	// stored value.
	if p.config.CacheEnabled {
		if cached := p.cache.Get(query); cached != nil {
			hit := *cached // shallow copy; shared slices are only read here
			hit.CacheHit = true
			return &hit, nil
		}
	}

	// Step 1: embed the query text.
	embedStart := time.Now()
	queryVec, err := p.embedder.Embed(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("embed失败: %w", err)
	}
	latency.EmbedMs = time.Since(embedStart).Milliseconds()

	// Step 2: hybrid search — vector and keyword results fused via RRF.
	searchStart := time.Now()
	vectorIDs, err := p.vectorDB.Search(ctx, queryVec, p.config.TopK)
	if err != nil {
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}
	keywordIDs, err := p.keywordDB.Search(ctx, query, p.config.TopK)
	if err != nil {
		// Degrade to pure vector search: feeding the vector list in twice
		// keeps RRF's output order identical to the vector ranking.
		p.logger.Warn("关键词检索失败,降级到pure vector", "error", err)
		keywordIDs = vectorIDs
	}
	fusedIDs := RRF(vectorIDs, keywordIDs, p.config.TopK)
	latency.SearchMs = time.Since(searchStart).Milliseconds()

	// Step 3: resolve the fused IDs to chunk texts.
	chunks, err := p.vectorDB.GetChunks(ctx, fusedIDs)
	if err != nil {
		return nil, fmt.Errorf("获取chunks失败: %w", err)
	}

	// Step 4: rerank; on failure keep the fused order, truncated to TopN.
	rerankStart := time.Now()
	rankedChunks, err := p.reranker.Rerank(ctx, query, chunks, p.config.RerankTopN)
	if err != nil {
		p.logger.Warn("Rerank失败,使用原排序", "error", err)
		rankedChunks = chunks
		if len(rankedChunks) > p.config.RerankTopN {
			rankedChunks = rankedChunks[:p.config.RerankTopN]
		}
	}
	latency.RerankMs = time.Since(rerankStart).Milliseconds()

	// Step 5: build the grounded prompt and generate the answer.
	genStart := time.Now()
	prompt := buildRAGPrompt(query, rankedChunks)
	answer, err := p.llm.Generate(ctx, prompt)
	if err != nil {
		return nil, fmt.Errorf("生成答案失败: %w", err)
	}
	latency.GenerateMs = time.Since(genStart).Milliseconds()
	latency.TotalMs = time.Since(start).Milliseconds()

	// Step 6: assemble the result.
	result := &QueryResult{
		Answer:  answer,
		Chunks:  extractTexts(rankedChunks),
		Sources: extractSources(rankedChunks),
		Latency: latency,
	}

	// Cache a copy, not the pointer handed to the caller, so later caller
	// mutations cannot corrupt the cached entry.
	if p.config.CacheEnabled {
		entry := *result
		p.cache.Set(query, &entry, p.config.CacheTTL)
	}

	p.logger.Info("query完成",
		"query", query,
		"total_ms", latency.TotalMs,
		"cache_hit", false,
		"chunks_used", len(rankedChunks),
	)
	return result, nil
}
// buildRAGPrompt assembles the grounded-QA prompt: a numbered context section
// built from the reranked chunks, followed by the system rules and the user
// question. The prompt text itself is unchanged from the original.
func buildRAGPrompt(query string, chunks []Chunk) string {
	// strings.Builder avoids the quadratic copying of `context += ...` in a
	// loop; each chunk is appended once.
	var docs strings.Builder
	for i, c := range chunks {
		fmt.Fprintf(&docs, "[文档%d - 来源:%s]\n%s\n\n", i+1, c.Source, c.Text)
	}
	return fmt.Sprintf(`你是一个企业知识库问答助手。
重要规则:
1. 只使用以下文档中的信息回答问题
2. 如果文档中没有相关信息,明确说"根据现有文档,无法回答此问题"
3. 不要使用任何文档之外的知识
4. 回答结构清晰,不超过300字
5. 在回答末尾注明使用的文档编号(如"参考:文档1、文档3")
参考文档:
%s
用户问题:%s
回答:`, docs.String(), query)
}
// extractTexts returns the Text field of every chunk, preserving order.
func extractTexts(chunks []Chunk) []string {
	out := make([]string, 0, len(chunks))
	for _, c := range chunks {
		out = append(out, c.Text)
	}
	return out
}
// extractSources maps each chunk to its citation record (ID and title),
// preserving order.
func extractSources(chunks []Chunk) []ChunkSource {
	out := make([]ChunkSource, 0, len(chunks))
	for _, c := range chunks {
		// NOTE(review): Score is left at its zero value — Chunk carries no
		// score field here; confirm whether rerank scores should be plumbed
		// through.
		out = append(out, ChunkSource{
			ChunkID:  c.ID,
			DocTitle: c.Title,
		})
	}
	return out
}
// Chunk is the retrieval unit flowing through the pipeline (reused from the
// Day 8 chunking definition).
type Chunk struct {
	ID     string // unique chunk identifier, as stored in the vector DB
	Text   string // chunk body, inserted verbatim into the prompt
	Title  string // title of the originating document
	Source string // source label shown in the prompt's context header
}
// Interface definitions: each Pipeline dependency is abstracted behind a
// small consumer-side interface so implementations can be swapped in tests.

// Embedder converts text into a dense float32 vector.
type Embedder interface {
	Embed(ctx context.Context, text string) ([]float32, error)
}

// VectorDB performs vector similarity search and resolves chunk IDs to chunks.
type VectorDB interface {
	Search(ctx context.Context, vec []float32, k int) ([]string, error)
	GetChunks(ctx context.Context, ids []string) ([]Chunk, error)
}

// KeywordDB performs lexical (keyword) search, returning chunk IDs.
type KeywordDB interface {
	Search(ctx context.Context, query string, k int) ([]string, error)
}

// Reranker reorders candidate chunks by relevance to query, keeping the top n.
type Reranker interface {
	Rerank(ctx context.Context, query string, chunks []Chunk, n int) ([]Chunk, error)
}

// LLMGenerator produces a completion for the given prompt.
type LLMGenerator interface {
	Generate(ctx context.Context, prompt string) (string, error)
}

// QueryCache stores query results keyed by the raw query string.
// Get returns nil on a miss (Pipeline.Query relies on this).
type QueryCache interface {
	Get(key string) *QueryResult
	Set(key string, val *QueryResult, ttl time.Duration)
}