Week 2 Day 9:Embedding + Vector Store - 苏格拉底教学

💡 今天你会把昨天切好的 chunks 变成向量,存入向量数据库,实现语义相似度检索——RAG 的核心引擎。

第一部分:问题驱动

🤔 问题1:为什么要把文本变成向量?

引导问题:

  1. 用户问"如何重置密码",直接用关键词搜能找到"账号密码修改流程"这篇文档吗?
  2. 两段意思相近但用词完全不同的文本,怎么判断它们"相似"?
  3. 计算机能直接比较两段中文的"语义距离"吗?

答案揭示:

  • 关键词搜索只能匹配字面,无法理解语义
  • Embedding 把文本映射到高维空间的一个点(向量)
  • 语义相近的文本,在向量空间中距离更近
  • 相似度搜索 = 在向量空间中找"最近的邻居"

你应该理解:

文本 → Embedding 模型 → [0.23, -0.87, 0.44, ...] (1536 维向量) "重置密码" → [v1, v2, v3, ...] "账号密码修改流程" → [w1, w2, w3, ...] # v 和 w 在向量空间中距离很近!

🤔 问题2:OpenAI text-embedding-3-small 是什么?

引导问题:

  1. 为什么是 1536 维而不是 100 维?维度越高越好吗?
  2. $0.02/1M tokens 意味着 embedding 10 万个 chunk 要多少钱?
  3. 和 text-embedding-ada-002 比,3-small 有什么优势?

答案揭示:

  • 1536 维:更高维度 = 更细腻的语义表达,但也更耗存储和计算
  • text-embedding-3-small 比 ada-002 成本低 5 倍,性能持平或更优
  • 对于企业知识库(约 10 万 chunks),embedding 费用约 $0.002,几乎可忽略

关键模型对比:

模型 维度 价格(/1M tokens) 适用场景
text-embedding-3-small 1536 $0.02 大多数 RAG 场景(推荐)
text-embedding-3-large 3072 $0.13 高精度需求
text-embedding-ada-002 1536 $0.10 旧系统兼容

🤔 问题3:向量存储为什么不能用普通数据库?

引导问题:

  1. 你有 10 万个向量,要找最相近的 5 个,怎么做到毫秒级响应?
  2. 暴力遍历 10 万次距离计算有多慢?(1536 维 × 10 万次浮点运算)
  3. 为什么 PostgreSQL 的 B-Tree 索引对向量搜索没用?

答案揭示:

  • 向量搜索 = Approximate Nearest Neighbor(ANN)问题
  • 普通有序索引(B-Tree)无法加速高维空间搜索
  • 需要专用的向量索引:HNSW 或 IVF

第二部分:关键概念

HNSW vs IVF 向量索引

HNSW(Hierarchical Navigable Small World):

原理:构建多层图结构,高层稀疏(快速导航),低层密集(精确搜索) 优点:查询速度快(毫秒级),增量插入友好 缺点:内存消耗大(索引需要常驻内存) 适用:实时系统,数据量 < 5000 万

IVF(Inverted File Index):

原理:先用 K-means 聚类,查询时只在相似的簇内搜索 优点:内存效率高,适合超大规模数据 缺点:需要预训练,新数据插入需重建索引 适用:离线批量场景,数据量 > 1 亿

结论:企业知识库(10 万–1000 万),选 HNSW。


cosine similarity vs dot product

cosine similarity(余弦相似度):

sim(A, B) = (A·B) / (|A| × |B|) 范围:[-1, 1],1 表示完全相同方向 特点:只看方向,不管向量长度 适用:文本语义匹配(通用推荐)

dot product(点积):

sim(A, B) = A·B 范围:任意实数 特点:同时考虑方向和长度 适用:向量已归一化时等价于 cosine,速度稍快

OpenAI 的 embedding 输出已归一化,用 dot product 等价于 cosine,且计算更快。


Qdrant vs pgvector 选型对比

维度 Qdrant pgvector
部署 独立服务 PostgreSQL 插件
性能 专为向量优化 通用 DB,较慢
payload filter 原生支持 SQL WHERE
运维成本 额外维护一个服务 复用现有 PG
适用 向量搜索为核心业务 已有 PG,轻量使用

本项目选 Qdrant:专用向量 DB,后续混合检索、filter 扩展更方便。


第三部分:动手实现

✅ 版本1:单个 embedding 调用

// internal/rag/embedder.go
package rag

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

type Embedder struct {
	client *openai.Client
	model  string
}

func NewEmbedder(apiKey string) *Embedder {
	return &Embedder{
		client: openai.NewClient(apiKey),
		model:  string(openai.SmallEmbedding3), // text-embedding-3-small
	}
}

// EmbedOne 对单个文本生成 embedding
func (e *Embedder) EmbedOne(ctx context.Context, text string) ([]float32, error) {
	resp, err := e.client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
		Model: openai.EmbeddingModel(e.model),
		Input: []string{text},
	})
	if err != nil {
		return nil, fmt.Errorf("embedding failed: %w", err)
	}

	if len(resp.Data) == 0 {
		return nil, fmt.Errorf("no embedding returned")
	}

	return resp.Data[0].Embedding, nil
}

问题: 有 1000 个 chunks 要 embed,一个个调用需要多久?(提示:每次调用约 200ms,串行则需 200 秒)


✅ 版本2:并发批量 embedding(核心实现)

// internal/rag/embedder.go(续)

import (
	"context"
	"fmt"
	"log/slog"
	"sync"

	openai "github.com/sashabaranov/go-openai"
	"golang.org/x/sync/errgroup"
)

const (
	batchSize      = 20 // OpenAI 每次最多 100 个,保守用 20
	maxConcurrency = 5  // 最多同时 5 个并发请求
)

// EmbedBatch 并发批量 embedding
// errgroup 负责错误收集,semaphore 控制并发度
func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	results := make([][]float32, len(texts))

	// semaphore:令牌桶,限制最大并发数
	sem := make(chan struct{}, maxConcurrency)

	g, gctx := errgroup.WithContext(ctx)
	var mu sync.Mutex

	// 分批处理
	for batchStart := 0; batchStart < len(texts); batchStart += batchSize {
		start := batchStart
		end := start + batchSize
		if end > len(texts) {
			end = len(texts)
		}
		batch := texts[start:end]

		g.Go(func() error {
			// 获取令牌(若已有 5 个并发则阻塞等待)
			sem <- struct{}{}
			defer func() { <-sem }()

			slog.Info("embedding batch", "start", start, "end", end, "size", len(batch))

			resp, err := e.client.CreateEmbeddings(gctx, openai.EmbeddingRequest{
				Model: openai.EmbeddingModel(e.model),
				Input: batch,
			})
			if err != nil {
				return fmt.Errorf("batch [%d:%d] failed: %w", start, end, err)
			}

			// 写回结果,加锁防止竞态
			mu.Lock()
			for i, data := range resp.Data {
				results[start+i] = data.Embedding
			}
			mu.Unlock()

			return nil
		})
	}

	// 等待所有 goroutine 完成,有任何错误则返回
	if err := g.Wait(); err != nil {
		return nil, err
	}

	return results, nil
}

思考题: semaphore 为什么用 chan struct{}?不用 sync.Mutex 的原因?

解析:

  • chan struct{} 是计数信号量,同一时刻可以有 maxConcurrency 个 goroutine 进入
  • sync.Mutex 是互斥锁,同一时刻只能有 1 个
  • 两者语义完全不同:Mutex 是"排他",chan 是"限流"

✅ 版本3:写入 Qdrant Vector DB

// internal/rag/vectorstore.go
package rag

import (
	"context"
	"fmt"

	"github.com/google/uuid"
	"github.com/qdrant/go-client/qdrant"
)

const (
	collectionName = "knowledge_base"
	vectorSize     = 1536 // text-embedding-3-small 维度
)

type VectorStore struct {
	client         *qdrant.Client
	collectionName string
}

func NewVectorStore(host string, port int) (*VectorStore, error) {
	client, err := qdrant.NewClient(&qdrant.Config{
		Host: host,
		Port: port,
	})
	if err != nil {
		return nil, fmt.Errorf("connect qdrant: %w", err)
	}

	return &VectorStore{
		client:         client,
		collectionName: collectionName,
	}, nil
}

// EnsureCollection 确保 collection 存在(幂等)
func (vs *VectorStore) EnsureCollection(ctx context.Context) error {
	exists, err := vs.client.CollectionExists(ctx, vs.collectionName)
	if err != nil {
		return err
	}
	if exists {
		return nil
	}

	return vs.client.CreateCollection(ctx, &qdrant.CreateCollection{
		CollectionName: vs.collectionName,
		VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
			Size:     vectorSize,
			Distance: qdrant.Distance_Cosine,
		}),
	})
}

// chunkIDToUUID 把 chunk ID 稳定转换为 UUID(不丢失原始信息)
func chunkIDToUUID(chunkID string) string {
	uid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(chunkID))
	return uid.String()
}

// UpsertChunks 批量写入 chunks 和对应向量
func (vs *VectorStore) UpsertChunks(ctx context.Context, chunks []Chunk, vectors [][]float32) error {
	if len(chunks) != len(vectors) {
		return fmt.Errorf("chunks(%d) and vectors(%d) length mismatch", len(chunks), len(vectors))
	}

	points := make([]*qdrant.PointStruct, len(chunks))
	for i, chunk := range chunks {
		uid := chunkIDToUUID(chunk.ID)
		points[i] = &qdrant.PointStruct{
			Id:      qdrant.NewIDWithUUID(uid),
			Vectors: qdrant.NewVectors(vectors[i]...),
			Payload: map[string]*qdrant.Value{
				"chunk_id":       qdrant.NewValueString(chunk.ID),
				"text":           qdrant.NewValueString(chunk.Text),
				"doc_id":         qdrant.NewValueString(chunk.DocID),
				"title":          qdrant.NewValueString(chunk.Title),
				"source":         qdrant.NewValueString(chunk.Source),
				"department":     qdrant.NewValueString(chunk.Metadata["department"]),
				"classification": qdrant.NewValueString(chunk.Metadata["classification"]),
				"chunk_index":    qdrant.NewValueString(chunk.Metadata["chunk_index"]),
			},
		}
	}

	_, err := vs.client.Upsert(ctx, &qdrant.UpsertPoints{
		CollectionName: vs.collectionName,
		Points:         points,
	})
	return err
}

// Search 向量相似度搜索
func (vs *VectorStore) Search(ctx context.Context, queryVec []float32, topK uint64) ([]*qdrant.ScoredPoint, error) {
	results, err := vs.client.Query(ctx, &qdrant.QueryPoints{
		CollectionName: vs.collectionName,
		Query:          qdrant.NewQuery(queryVec...),
		Limit:          &topK,
		WithPayload:    qdrant.NewWithPayload(true),
	})
	if err != nil {
		return nil, fmt.Errorf("vector search: %w", err)
	}

	return results, nil
}

✅ 版本4:完整流水线(从 chunks 到可查询)

// cmd/indexer/main.go
package main

import (
	"context"
	"encoding/json"
	"log/slog"
	"os"
	"time"

	"github.com/yourname/agent-runtime/internal/rag"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" {
		slog.Error("OPENAI_API_KEY not set")
		os.Exit(1)
	}

	// 1. 读取 Day 8 生成的 chunks.jsonl
	chunks, err := loadChunks("evals/chunks.jsonl")
	if err != nil {
		slog.Error("load chunks failed", "error", err)
		os.Exit(1)
	}
	slog.Info("loaded chunks", "count", len(chunks))

	// 2. 提取文本,用 title + text 一起 embed,增强检索精度
	texts := make([]string, len(chunks))
	for i, c := range chunks {
		texts[i] = c.Title + "\n\n" + c.Text
	}

	// 3. 并发批量 embedding
	embedder := rag.NewEmbedder(apiKey)
	slog.Info("starting batch embedding", "chunks", len(texts))

	vectors, err := embedder.EmbedBatch(ctx, texts)
	if err != nil {
		slog.Error("embedding failed", "error", err)
		os.Exit(1)
	}
	slog.Info("embedding complete", "vectors", len(vectors))

	// 4. 初始化 Qdrant,确保 collection 存在
	vs, err := rag.NewVectorStore("localhost", 6334)
	if err != nil {
		slog.Error("vectorstore init failed", "error", err)
		os.Exit(1)
	}

	if err := vs.EnsureCollection(ctx); err != nil {
		slog.Error("ensure collection failed", "error", err)
		os.Exit(1)
	}

	// 5. 写入向量
	if err := vs.UpsertChunks(ctx, chunks, vectors); err != nil {
		slog.Error("upsert failed", "error", err)
		os.Exit(1)
	}

	slog.Info("indexing complete", "total_chunks", len(chunks))
}

func loadChunks(path string) ([]rag.Chunk, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	var chunks []rag.Chunk
	decoder := json.NewDecoder(file)
	for decoder.More() {
		var c rag.Chunk
		if err := decoder.Decode(&c); err != nil {
			return nil, err
		}
		chunks = append(chunks, c)
	}
	return chunks, nil
}

✅ 版本5:启动 Qdrant 并验证

# 启动 Qdrant(本地开发)
docker run -d \
  --name qdrant \
  -p 6333:6333 \
  -p 6334:6334 \
  -v $(pwd)/data/qdrant:/qdrant/storage \
  qdrant/qdrant:latest

# 验证 Qdrant 是否运行
curl http://localhost:6333/collections

# 运行 indexer
OPENAI_API_KEY=sk-xxx go run cmd/indexer/main.go

# 验证数据写入
curl http://localhost:6333/collections/knowledge_base

# 验证点数量
curl http://localhost:6333/collections/knowledge_base/points/count

第四部分:关键概念深入

errgroup 使用规范

// 正确用法:g.Go 内必须捕获所有错误
g, gctx := errgroup.WithContext(parentCtx)

g.Go(func() error {
    // 注意:使用 gctx 而不是 parentCtx
    // 若其他 goroutine 出错,gctx 会被取消
    result, err := doWork(gctx)
    if err != nil {
        return err // errgroup 会收集第一个 error
    }
    // 处理 result...
    return nil
})

// g.Wait() 会:
// 1. 等所有 goroutine 完成
// 2. 返回第一个非 nil error
// 3. 同时取消 gctx(通知其他 goroutine 停止)
if err := g.Wait(); err != nil {
    return err
}

面试题: errgroup 的 ctx 取消后,正在运行的 goroutine 会自动停止吗?

答: 不会自动停止,goroutine 本身需要监听 ctx.Done():

g.Go(func() error {
    for {
        select {
        case <-gctx.Done():
            return gctx.Err() // 收到取消信号,主动退出
        default:
            // 继续工作
        }
    }
})

Embedding 最佳实践

// 1. query 端和 document 端必须用同一模型
// 错误:document 用 ada-002,query 用 3-small → 向量空间不兼容
// 正确:统一用 text-embedding-3-small

// 2. 加前缀可提升部分模型的精度(E5 系列推荐)
func prepareText(text string, isQuery bool) string {
    if isQuery {
        return "query: " + text
    }
    return "passage: " + text
}
// OpenAI 3-small 不强制要求前缀,但可以尝试

// 3. 批量请求远优于单个请求
// 串行 1000 次:约 200 秒
// 批量(20/批)+ 并发 5:约 2 秒(100x 加速)

// 4. 考虑 embedding 缓存(相同文本不重复调用)
type CachedEmbedder struct {
    inner *Embedder
    cache map[string][]float32
    mu    sync.RWMutex
}

第五部分:自测清单

运行前,问自己:

  • 我能解释 embedding 为什么能表示语义相似度?
  • errgroup 和 goroutine+WaitGroup 的核心区别是什么?
  • semaphore 用 chan struct{} 而不是 sync.Mutex 的原因?
  • HNSW 和 IVF 各适合什么场景?
  • cosine similarity vs dot product,OpenAI 的向量用哪个更合适?
  • Qdrant 重启后数据会丢失吗?(提示:看 docker run 的 -v 参数)
  • 为什么要用 SHA1 把 chunk ID 转成 UUID?

第六部分:作业

任务1:跑通 embedding 流水线

# 确保环境就绪
docker ps | grep qdrant       # Qdrant 运行中
ls evals/chunks.jsonl         # Day 8 生成的 chunks 存在

# 运行 indexer
OPENAI_API_KEY=sk-xxx go run cmd/indexer/main.go

# 验证
curl http://localhost:6333/collections/knowledge_base/points/count

任务2:实现简单查询验证

// cmd/search/main.go
// 测试:把 query embed 后搜索,人工判断返回 chunks 是否相关
func main() {
    query := "如何重置密码"
    // 1. embed query
    vec, _ := embedder.EmbedOne(ctx, query)
    // 2. 搜索 top-5
    results, _ := vs.Search(ctx, vec, 5)
    // 3. 打印返回 chunk 的 text,人工判断相关性
    for _, r := range results {
        fmt.Printf("score=%.4f text=%s\n",
            r.Score, r.Payload["text"].GetStringValue())
    }
}

任务3:性能分析

  • 记录 embedding 100 个 chunks 的总耗时
  • 记录向量搜索延迟(目标 P99 < 20ms)
  • 计算本次 embedding API 的费用(tokens × $0.02/1M)

任务4:思考题

  • 如果 chunks.jsonl 有 10 万行,当前实现会有什么瓶颈?
  • 向量存储满了怎么办?Qdrant 支持水平扩展吗?
  • 为什么 indexer 要用 context.WithTimeout(10 * time.Minute) 而不是无限等待?

第七部分:常见问题

Q: embedding API 返回 429(rate limit)怎么办?

A: 加指数退避重试:

import "github.com/cenkalti/backoff/v4"

operation := func() error {
    resp, err := e.client.CreateEmbeddings(ctx, req)
    if err != nil {
        return err
    }
    // 处理 resp...
    return nil
}

err := backoff.Retry(operation, backoff.WithMaxRetries(
    backoff.NewExponentialBackOff(), 3,
))

Q: 搜索结果不相关,怎么排查?

A: 常见原因:

  1. query 和 document embed 时文本格式不对称(一个加了前缀,一个没加)
  2. chunk 太短,语义信息不足(<50 词的 chunk 效果差)
  3. collection 的 distance 设置错了(应用 Cosine 不是 Euclid)
  4. 向量维度不匹配(更换模型后忘记重建 collection)

Q: Qdrant 的 collection 如何设置 HNSW 参数?

return vs.client.CreateCollection(ctx, &qdrant.CreateCollection{
    CollectionName: vs.collectionName,
    VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
        Size:     vectorSize,
        Distance: qdrant.Distance_Cosine,
        HnswConfig: &qdrant.HnswConfigDiff{
            M:           qdrant.PtrOf(uint64(16)),  // 每个节点的连接数,越大精度越高,内存越多
            EfConstruct: qdrant.PtrOf(uint64(100)), // 构建时的搜索宽度
        },
    }),
})

Q: 如何更新已存在的向量?

A: Qdrant 的 Upsert 是幂等的,相同 ID 会覆盖旧数据:

// 文档更新时:重新生成 chunk 和 embedding,用相同 UUID upsert 即可
// 注意:chunk ID → UUID 的映射必须保持不变(用同一个 hash 函数)

第八部分:配套算法

算法1:层序遍历(Level Order Traversal)

为什么看这道题? HNSW 的图结构遍历和 BFS 层序有异曲同工之妙:从入口节点出发,按层搜索最近邻。

// LeetCode 102: Binary Tree Level Order Traversal
// 返回每层节点值的列表
func levelOrder(root *TreeNode) [][]int {
    if root == nil {
        return nil
    }

    var result [][]int
    queue := []*TreeNode{root}

    for len(queue) > 0 {
        levelSize := len(queue) // 记录当前层的节点数(关键!)
        var level []int

        for i := 0; i < levelSize; i++ {
            node := queue[0]
            queue = queue[1:]

            level = append(level, node.Val)

            if node.Left != nil {
                queue = append(queue, node.Left)
            }
            if node.Right != nil {
                queue = append(queue, node.Right)
            }
        }

        result = append(result, level)
    }

    return result
}

关键点: 进入内层循环前记录 levelSize,确保只处理当前层,不混入下一层。

复杂度: Time O(n),Space O(n)(最宽一层的节点数)

变体: 锯齿形层序(LeetCode 103)——奇数层从左到右,偶数层从右到左:

// 只需在追加 level 时按层奇偶决定是否 reverse
if len(result)%2 == 1 {
    // 反转 level
    for i, j := 0, len(level)-1; i < j; i, j = i+1, j-1 {
        level[i], level[j] = level[j], level[i]
    }
}
result = append(result, level)

算法2:验证二叉搜索树(Validate BST)

为什么? 向量索引维护高维空间的有序性,BST 是有序数据结构的经典。理解约束传递对于理解索引约束很有帮助。

// LeetCode 98: Validate Binary Search Tree
func isValidBST(root *TreeNode) bool {
    return validate(root, nil, nil)
}

// 用 min/max 边界传递约束,而不是只比较父子节点
// 错误做法:只看 node.Val > node.Left.Val(遗漏跨层约束)
// 正确做法:整个左子树的所有节点都必须 < 当前节点
func validate(node *TreeNode, min, max *int) bool {
    if node == nil {
        return true
    }

    if min != nil && node.Val <= *min {
        return false
    }
    if max != nil && node.Val >= *max {
        return false
    }

    // 左子树:上界更新为当前节点值
    // 右子树:下界更新为当前节点值
    return validate(node.Left, min, &node.Val) &&
        validate(node.Right, &node.Val, max)
}

陷阱: 只比较相邻父子节点会遗漏跨层约束。经典反例:

5 / \ 1 4 / \ 3 6 # 4 < 5 看起来合法,但 3 < 5 违反 BST:右子树所有节点都必须 > 5

复杂度: Time O(n),Space O(h)(递归栈深度)


算法3:BST 第 K 小元素(Kth Smallest in BST)

为什么? 向量搜索返回 top-K,理解在有序结构中高效找第 K 个元素,有助于设计高效检索逻辑。

// LeetCode 230: Kth Smallest Element in a BST
// 思路:中序遍历 BST 得到升序序列,第 k 个即答案
func kthSmallest(root *TreeNode, k int) int {
    count := 0
    result := 0

    var inorder func(node *TreeNode)
    inorder = func(node *TreeNode) {
        if node == nil || count >= k {
            return
        }

        inorder(node.Left)

        count++
        if count == k {
            result = node.Val
            return
        }

        inorder(node.Right)
    }

    inorder(root)
    return result
}

// 进阶:如果频繁查询 kth smallest,怎么优化到 O(log n)?
// 思路:在每个节点额外记录左子树大小
type AugmentedNode struct {
    Val       int
    Left      *AugmentedNode
    Right     *AugmentedNode
    LeftCount int // 左子树节点总数
}

func kthSmallestAugmented(root *AugmentedNode, k int) int {
    if root == nil {
        return -1
    }
    leftCount := root.LeftCount
    if k == leftCount+1 {
        return root.Val // 当前节点就是第 k 小
    } else if k <= leftCount {
        return kthSmallestAugmented(root.Left, k)
    } else {
        return kthSmallestAugmented(root.Right, k-leftCount-1)
    }
}

复杂度: 基础版 O(n),增强版 O(h)(h 为树高)


下一步:Day 10 预告

明天我们会:

  1. 实现 BM25 关键词搜索(bleve 库)
  2. 用 Reciprocal Rank Fusion(RRF)融合向量搜索和关键词搜索
  3. 加入 Metadata filter(按部门、密级过滤)

准备问题:

  • 如果用户搜索"GPT-4o",向量搜索为什么可能找不到?
  • BM25 和 TF-IDF 的区别是什么?
  • 为什么不直接平均两个搜索的分数,而要用 RRF?