// evals/runner.go
package main
import (
	"context"
	"encoding/json"
	"fmt"
	"math"
	"os"
	"strings"
	"time"
)
// EvalResult holds the full outcome of evaluating a single test case:
// the raw inputs/outputs, retrieval-quality metrics, the LLM judge's
// grade, and timing information.
type EvalResult struct {
	Query           string   `json:"query"`
	Category        string   `json:"category"`
	Difficulty      string   `json:"difficulty"`
	RetrievedIDs    []string `json:"retrieved_ids"`
	ExpectedIDs     []string `json:"expected_ids"`
	GeneratedAnswer string   `json:"generated_answer"`
	ExpectedAnswer  string   `json:"expected_answer"`
	// NOTE(review): the *At5 names suggest K=5, but RunSingle computes
	// them with the runner's configured k — the JSON keys are kept for
	// backward compatibility.
	PrecisionAt5 float64 `json:"precision_at_5"`
	RecallAt5    float64 `json:"recall_at_5"`
	MRR          float64 `json:"mrr"`
	NDCG         float64 `json:"ndcg"`
	LLMScore     float64 `json:"llm_score"` // judge grade on a 1-5 scale (0 on failure)
	LLMReason    string  `json:"llm_reason"`
	LatencyMs    int64   `json:"latency_ms"`
	Timestamp    string  `json:"timestamp"` // RFC3339 start time of the case
}
// EvalSummary aggregates per-case results into dataset-level averages,
// with breakdowns keyed by case category and by difficulty.
type EvalSummary struct {
	TotalCases      int                        `json:"total_cases"`
	AvgPrecisionAt5 float64                    `json:"avg_precision_at_5"`
	AvgRecallAt5    float64                    `json:"avg_recall_at_5"`
	AvgMRR          float64                    `json:"avg_mrr"`
	AvgNDCG         float64                    `json:"avg_ndcg"`
	AvgLLMScore     float64                    `json:"avg_llm_score"`
	AvgLatencyMs    float64                    `json:"avg_latency_ms"`
	ByCategory      map[string]CategoryMetrics `json:"by_category"`
	ByDifficulty    map[string]CategoryMetrics `json:"by_difficulty"`
}
// CategoryMetrics is one per-group breakdown row (used for both the
// category and the difficulty maps in EvalSummary).
type CategoryMetrics struct {
	Count        int     `json:"count"` // number of cases in the group
	AvgPrecision float64 `json:"avg_precision"`
	AvgRecall    float64 `json:"avg_recall"`
	AvgLLMScore  float64 `json:"avg_llm_score"`
}
// RAGPipeline is the interface the system under test must implement:
// retrieval of the top-k chunk IDs for a query, and answer generation
// conditioned on those chunks.
type RAGPipeline interface {
	// Retrieve returns the IDs of the top-k chunks for the query.
	Retrieve(ctx context.Context, query string, k int) ([]string, error)
	// Generate produces an answer to the query grounded in the given chunks.
	Generate(ctx context.Context, query string, chunkIDs []string) (string, error)
}
// LLMClient is the minimal completion interface used for LLM-as-judge
// scoring.
type LLMClient interface {
	Complete(ctx context.Context, prompt string) (string, error)
}
// EvalRunner drives the evaluation: it runs each case through the
// pipeline, computes retrieval metrics, and grades answers with an LLM.
type EvalRunner struct {
	pipeline  RAGPipeline
	llmClient LLMClient // judge model client
	k         int       // retrieval depth (top-K) used for every query
}
// NewEvalRunner constructs an EvalRunner wired to the given pipeline
// and judge client; k is the retrieval depth used for every query.
func NewEvalRunner(pipeline RAGPipeline, llmClient LLMClient, k int) *EvalRunner {
	runner := &EvalRunner{k: k}
	runner.pipeline = pipeline
	runner.llmClient = llmClient
	return runner
}
// RunSingle evaluates one case end to end: retrieve, score retrieval,
// generate an answer, and grade it with the LLM judge. Failures are
// logged and surface as zero-valued metrics rather than aborting, so
// one bad case cannot stop a whole evaluation run.
func (r *EvalRunner) RunSingle(ctx context.Context, ec EvalCase) EvalResult {
	start := time.Now()
	result := EvalResult{
		Query:          ec.Query,
		Category:       ec.Category,
		Difficulty:     ec.Difficulty,
		ExpectedIDs:    ec.ExpectedSources,
		ExpectedAnswer: ec.ExpectedAnswer,
		Timestamp:      start.Format(time.RFC3339),
	}
	// Step 1: retrieval.
	retrievedIDs, err := r.pipeline.Retrieve(ctx, ec.Query, r.k)
	if err != nil {
		fmt.Printf("检索失败 [%s]: %v\n", ec.Query, err)
		// Fix: record latency on the early-return path too; previously a
		// failed retrieval reported a misleading LatencyMs of 0.
		result.LatencyMs = time.Since(start).Milliseconds()
		return result
	}
	result.RetrievedIDs = retrievedIDs
	// Step 2: retrieval metrics (note: the *At5 fields are computed with
	// the configured k, not a hard-coded 5).
	result.PrecisionAt5 = calcPrecisionAtK(retrievedIDs, ec.ExpectedSources, r.k)
	result.RecallAt5 = calcRecallAtK(retrievedIDs, ec.ExpectedSources, r.k)
	result.MRR = calcMRR(retrievedIDs, ec.ExpectedSources)
	result.NDCG = calcNDCG(retrievedIDs, ec.ExpectedSources, r.k)
	// Step 3: answer generation.
	generatedAnswer, genErr := r.pipeline.Generate(ctx, ec.Query, retrievedIDs)
	if genErr != nil {
		fmt.Printf("生成失败 [%s]: %v\n", ec.Query, genErr)
	}
	result.GeneratedAnswer = generatedAnswer
	// Step 4: LLM-as-judge scoring. Fix: skip the judge when generation
	// failed — grading an empty answer wastes an LLM call and pollutes
	// the score average with a meaningless grade.
	if genErr == nil {
		score, reason := r.llmJudge(ctx, ec.Query, generatedAnswer, ec.ExpectedAnswer)
		result.LLMScore = score
		result.LLMReason = reason
	} else {
		result.LLMReason = "生成失败,未评分"
	}
	result.LatencyMs = time.Since(start).Milliseconds()
	return result
}
// Run evaluates every case in dataset sequentially and returns the
// aggregate summary together with the per-case results.
// Fix: the loop now honors ctx cancellation between cases (previously
// the returned error was unconditionally nil and cancellation was
// ignored); on cancellation it returns the partial results collected so
// far along with ctx.Err().
func (r *EvalRunner) Run(ctx context.Context, dataset []EvalCase) (*EvalSummary, []EvalResult, error) {
	results := make([]EvalResult, 0, len(dataset))
	for i, ec := range dataset {
		if err := ctx.Err(); err != nil {
			return r.summarize(results), results, err
		}
		fmt.Printf("评测 [%d/%d]: %s\n", i+1, len(dataset), ec.Query)
		results = append(results, r.RunSingle(ctx, ec))
	}
	return r.summarize(results), results, nil
}
// llmJudge asks the judge LLM to grade a generated answer against the
// reference answer on a 1-5 scale, returning the score and a short
// rationale. On any failure it returns (0, <reason>) so callers can
// distinguish "ungraded" from a real 1-5 grade.
func (r *EvalRunner) llmJudge(ctx context.Context, query, generated, expected string) (float64, string) {
	prompt := fmt.Sprintf(`你是一个严格的QA评测专家。
用户问题:%s
参考答案:%s
系统生成的答案:%s
请评分(1-5分)并说明理由:
1分:完全错误或无关
2分:有部分相关内容,但主要信息缺失
3分:基本正确,但不完整或有轻微偏差
4分:正确且较完整
5分:完全正确、完整、清晰
只返回JSON格式:{"score": 数字, "reason": "简短理由"}`, query, expected, generated)
	resp, err := r.llmClient.Complete(ctx, prompt)
	if err != nil {
		return 0, "LLM评分失败"
	}
	// Fix: LLMs frequently wrap their JSON in markdown fences or
	// surrounding prose despite the instruction; unmarshalling the raw
	// response therefore failed on perfectly usable replies. Extract the
	// outermost {...} object before parsing.
	objStart := strings.Index(resp, "{")
	objEnd := strings.LastIndex(resp, "}")
	if objStart < 0 || objEnd < objStart {
		return 0, "解析评分失败"
	}
	var judgeResult struct {
		Score  float64 `json:"score"`
		Reason string  `json:"reason"`
	}
	if err := json.Unmarshal([]byte(resp[objStart:objEnd+1]), &judgeResult); err != nil {
		return 0, "解析评分失败"
	}
	return judgeResult.Score, judgeResult.Reason
}
// summarize aggregates per-case results into overall averages plus
// per-category and per-difficulty breakdowns.
// Fix: an empty result set previously divided by zero, filling every
// average with NaN — which json.Marshal rejects outright when the
// summary is serialized. The empty case now returns a zero-valued
// summary. The copy-pasted category/difficulty aggregation loops are
// folded into one helper.
func (r *EvalRunner) summarize(results []EvalResult) *EvalSummary {
	summary := &EvalSummary{
		TotalCases:   len(results),
		ByCategory:   make(map[string]CategoryMetrics),
		ByDifficulty: make(map[string]CategoryMetrics),
	}
	if len(results) == 0 {
		return summary
	}
	var totalPrec, totalRecall, totalMRR, totalNDCG, totalScore, totalLatency float64
	catMap := make(map[string][]EvalResult)
	diffMap := make(map[string][]EvalResult)
	for _, res := range results {
		totalPrec += res.PrecisionAt5
		totalRecall += res.RecallAt5
		totalMRR += res.MRR
		totalNDCG += res.NDCG
		totalScore += res.LLMScore
		totalLatency += float64(res.LatencyMs)
		catMap[res.Category] = append(catMap[res.Category], res)
		diffMap[res.Difficulty] = append(diffMap[res.Difficulty], res)
	}
	n := float64(len(results))
	summary.AvgPrecisionAt5 = totalPrec / n
	summary.AvgRecallAt5 = totalRecall / n
	summary.AvgMRR = totalMRR / n
	summary.AvgNDCG = totalNDCG / n
	summary.AvgLLMScore = totalScore / n
	summary.AvgLatencyMs = totalLatency / n
	// Per-category breakdown.
	for cat, rs := range catMap {
		summary.ByCategory[cat] = aggregateGroup(rs)
	}
	// Per-difficulty breakdown.
	for diff, rs := range diffMap {
		summary.ByDifficulty[diff] = aggregateGroup(rs)
	}
	return summary
}

// aggregateGroup averages precision, recall and LLM score over one
// group of results (a category or difficulty bucket). Groups built by
// summarize are always non-empty, so the division is safe.
func aggregateGroup(rs []EvalResult) CategoryMetrics {
	var prec, rec, score float64
	for _, res := range rs {
		prec += res.PrecisionAt5
		rec += res.RecallAt5
		score += res.LLMScore
	}
	cnt := float64(len(rs))
	return CategoryMetrics{
		Count:        len(rs),
		AvgPrecision: prec / cnt,
		AvgRecall:    rec / cnt,
		AvgLLMScore:  score / cnt,
	}
}