Spaces:
Sleeping
Sleeping
File size: 6,682 Bytes
0bcc252 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import fs from 'fs/promises';
import {exec} from 'child_process';
import {promisify} from 'util';
import {getResponse} from '../agent';
import {generateObject} from 'ai';
import {GEMINI_API_KEY} from '../config';
import {z} from 'zod';
import {AnswerAction, TrackerContext} from "../types";
import {createGoogleGenerativeAI} from "@ai-sdk/google";
// Promise-returning wrapper around child_process.exec, for async/await use.
const execAsync = promisify(exec);
/** One entry of the evaluation input file: a question and its ground-truth answer. */
interface Question {
  /** The question posed to the agent. */
  question: string;
  /** The expected (reference) answer. */
  answer: string;
}
/** Outcome of evaluating a single question against the agent's answer. */
interface EvaluationResult {
  /** Whether the judge considered the actual answer equivalent to the expected one. */
  pass: boolean;
  /** The judge's explanation for the pass/fail decision. */
  reason: string;
  /** Number of agent steps taken to produce the answer (0 on error). */
  total_steps: number;
  /** Total tokens consumed for this question (0 on error). */
  total_tokens: number;
  question: string;
  expected_answer: string;
  actual_answer: string;
}
/** Aggregate statistics over a full batch of EvaluationResults. */
interface EvaluationStats {
  model_name: string;
  /** Percentage of passing results, 0-100. */
  pass_rate: number;
  avg_steps: number;
  max_steps: number;
  min_steps: number;
  median_steps: number;
  avg_tokens: number;
  median_tokens: number;
  max_tokens: number;
  min_tokens: number;
}
/**
 * Returns the median of a list of numbers without mutating the input.
 *
 * @param numbers - Values to take the median of, in any order.
 * @returns The median, or 0 for an empty list (previously an empty list
 *          produced NaN, which poisoned downstream statistics).
 */
function calculateMedian(numbers: number[]): number {
  if (numbers.length === 0) {
    return 0;
  }
  // Copy before sorting so the caller's array is left untouched.
  const sorted = [...numbers].sort((a, b) => a - b);
  const middle = Math.floor(sorted.length / 2);
  if (sorted.length % 2 === 0) {
    // Even count: average the two central values.
    return (sorted[middle - 1] + sorted[middle]) / 2;
  }
  return sorted[middle];
}
/**
 * Aggregates per-question results into summary statistics.
 *
 * @param results - Per-question evaluation outcomes.
 * @param modelName - Name of the model under evaluation, copied into the stats.
 * @returns Pass rate (percent) plus avg/median/max/min of steps and tokens.
 *          An empty result set returns all-zero stats (previously this
 *          produced NaN averages and ±Infinity from Math.max/min on an
 *          empty spread).
 */
function calculateStats(results: EvaluationResult[], modelName: string): EvaluationStats {
  if (results.length === 0) {
    return {
      model_name: modelName,
      pass_rate: 0,
      avg_steps: 0,
      max_steps: 0,
      min_steps: 0,
      median_steps: 0,
      avg_tokens: 0,
      median_tokens: 0,
      max_tokens: 0,
      min_tokens: 0
    };
  }
  const steps = results.map(r => r.total_steps);
  const tokens = results.map(r => r.total_tokens);
  const passCount = results.filter(r => r.pass).length;
  return {
    model_name: modelName,
    pass_rate: (passCount / results.length) * 100,
    avg_steps: steps.reduce((a, b) => a + b, 0) / steps.length,
    max_steps: Math.max(...steps),
    min_steps: Math.min(...steps),
    median_steps: calculateMedian(steps),
    avg_tokens: tokens.reduce((a, b) => a + b, 0) / tokens.length,
    median_tokens: calculateMedian(tokens),
    max_tokens: Math.max(...tokens),
    min_tokens: Math.min(...tokens)
  };
}
/**
 * Pretty-prints an EvaluationStats summary to stdout, one line per metric.
 */
function printStats(stats: EvaluationStats): void {
  const report = [
    '\n=== Evaluation Statistics ===',
    `Model: ${stats.model_name}`,
    `Pass Rate: ${stats.pass_rate.toFixed(0)}%`,
    `Average Steps: ${stats.avg_steps.toFixed(0)}`,
    `Maximum Steps: ${stats.max_steps}`,
    `Minimum Steps: ${stats.min_steps}`,
    `Median Steps: ${stats.median_steps.toFixed(0)}`,
    `Average Tokens: ${stats.avg_tokens.toFixed(0)}`,
    `Median Tokens: ${stats.median_tokens.toFixed(0)}`,
    `Maximum Tokens: ${stats.max_tokens}`,
    `Minimum Tokens: ${stats.min_tokens}`,
    '===========================\n'
  ];
  for (const line of report) {
    console.log(line);
  }
}
/**
 * Resolves the short hash of the current git HEAD for tagging output files.
 *
 * @returns The abbreviated commit hash, or 'unknown' (after logging the
 *          error) when git is unavailable or the cwd is not a repository.
 */
async function getCurrentGitCommit(): Promise<string> {
  try {
    const result = await execAsync('git rev-parse --short HEAD');
    return result.stdout.trim();
  } catch (error) {
    console.error('Error getting git commit:', error);
    return 'unknown';
  }
}
/**
 * Uses Gemini as an LLM judge to decide whether `actualAnswer` conveys the
 * same core information as `expectedAnswer`.
 *
 * @param expectedAnswer - Ground-truth answer from the evaluation set.
 * @param actualAnswer - Answer produced by the agent under test.
 * @returns The judge's pass/fail verdict and reasoning; on any API error
 *          a failing result is returned with the error embedded in `reason`.
 */
async function evaluateAnswer(expectedAnswer: string, actualAnswer: string): Promise<{ pass: boolean; reason: string }> {
  // Fix: removed a stray trailing apostrophe that previously leaked into
  // every judge prompt after "...in the actual answer."
  const prompt = `You are a deterministic evaluator with zero temperature. Compare the following expected answer with the actual answer and determine if they convey the same information.
Expected answer: ${expectedAnswer}
Actual answer: ${actualAnswer}
Minor wording differences are acceptable as long as the core information of the expected answer is preserved in the actual answer.`;
  const schema = z.object({
    pass: z.boolean().describe('Whether the actual answer matches the expected answer'),
    reason: z.string().describe('Detailed explanation of why the evaluation passed or failed')
  });
  try {
    const result = await generateObject({
      model: createGoogleGenerativeAI({apiKey: GEMINI_API_KEY})('gemini-2.0-flash'), // pinned to gemini-2.0-flash so judging is comparable across runs
      schema,
      prompt,
      maxTokens: 1000,
      temperature: 0 // Setting temperature to 0 for deterministic output
    });
    return result.object;
  } catch (error) {
    console.error('Evaluation failed:', error);
    return {
      pass: false,
      reason: `Evaluation error: ${error}`
    };
  }
}
/**
 * Runs the agent over every question in `inputFile` (a JSON array of
 * {question, answer}), judges each response with evaluateAnswer, prints
 * aggregate statistics, and writes per-question results plus statistics
 * to `eval-<commit>-<model>.json` in the current directory.
 *
 * Questions are processed sequentially; each question's step/token usage
 * is read from the per-run tracker context returned by getResponse.
 *
 * @param inputFile - Path to the JSON file of questions and expected answers.
 */
async function batchEvaluate(inputFile: string): Promise<void> {
  // Read and parse input file
  const questions: Question[] = JSON.parse(await fs.readFile(inputFile, 'utf-8'));
  const results: EvaluationResult[] = [];
  const gitCommit = await getCurrentGitCommit();
  const modelName = process.env.DEFAULT_MODEL_NAME || 'unknown';
  // Fix: sanitize the model name for use in a filename — provider-prefixed
  // names (e.g. "vendor/model") would otherwise point into a nonexistent
  // directory and make the final writeFile fail.
  const outputFile = `eval-${gitCommit}-${modelName.replace(/[^\w.-]+/g, '_')}.json`;
  // Process each question
  for (let i = 0; i < questions.length; i++) {
    const {question, answer: expectedAnswer} = questions[i];
    console.log(`\nProcessing question ${i + 1}/${questions.length}: ${question}`);
    try {
      // Get response using the agent
      const {
        result: response,
        context
      } = await getResponse(question) as { result: AnswerAction; context: TrackerContext };
      // Get response using the streaming agent
      // const {
      //   result: response,
      //   context
      // } = await getResponseStreamingAgent(question) as { result: AnswerAction; context: TrackerContext };
      const actualAnswer = response.answer;
      // Evaluate the response
      const evaluation = await evaluateAnswer(expectedAnswer, actualAnswer);
      // Record results
      results.push({
        pass: evaluation.pass,
        reason: evaluation.reason,
        total_steps: context.actionTracker.getState().totalStep,
        total_tokens: context.tokenTracker.getTotalUsage().totalTokens,
        question,
        expected_answer: expectedAnswer,
        actual_answer: actualAnswer
      });
      console.log(`Evaluation: ${evaluation.pass ? 'PASS' : 'FAIL'}`);
      console.log(`Reason: ${evaluation.reason}`);
    } catch (error) {
      // A failed run still produces a (failing) result row so the output
      // file stays aligned with the question list.
      console.error(`Error processing question: ${question}`, error);
      results.push({
        pass: false,
        reason: `Error: ${error}`,
        total_steps: 0,
        total_tokens: 0,
        question,
        expected_answer: expectedAnswer,
        actual_answer: 'Error occurred'
      });
    }
  }
  // Calculate and print statistics
  const stats = calculateStats(results, modelName);
  printStats(stats);
  // Save results
  await fs.writeFile(outputFile, JSON.stringify({
    results,
    statistics: stats
  }, null, 2));
  console.log(`\nEvaluation results saved to ${outputFile}`);
}
// CLI entry point: `node <script> <questions.json>` runs the batch
// evaluation directly; importing the module does nothing.
if (require.main === module) {
  const inputFile = process.argv[2];
  if (inputFile) {
    batchEvaluate(inputFile).catch(console.error);
  } else {
    console.error('Please provide an input file path');
    process.exit(1);
  }
}

export {batchEvaluate};
|