import json
import re
import time
from transformers import GPT2Tokenizer
from utils import model_prompting, f1_score, exact_match_score, get_bert_score
from beartype.typing import Any, Dict, List, Tuple, Optional
# Initialize tokenizer for token counting (used in cost calculation)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
class LLMEngine:
"""
A class to manage interactions with multiple language models and evaluate their performance.
Handles model selection, querying, cost calculation, and performance evaluation
using various metrics for different tasks.
"""
    def __init__(self, llm_names: List[str], llm_description: Dict[str, Dict[str, Any]]):
        """
        Initialize the LLM Engine with available models and their descriptions.

        Args:
            llm_names: List of language model names available in the engine
            llm_description: Dictionary containing model configurations and pricing details
                Structure: {
                    "model_name": {
                        "model": "api_identifier",
                        "input_price": cost_per_input_token,
                        "output_price": cost_per_output_token,
                        ...
                    },
                    ...
                }
        """
        self.llm_names = llm_names
        self.llm_description = llm_description

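    # Example construction (hypothetical model names and per-token prices, shown only to
    # illustrate the expected llm_description layout; real entries depend on your provider):
    #
    #     engine = LLMEngine(
    #         llm_names=["model-a", "model-b"],
    #         llm_description={
    #             "model-a": {"model": "provider/model-a", "input_price": 1e-06, "output_price": 2e-06},
    #             "model-b": {"model": "provider/model-b", "input_price": 5e-07, "output_price": 1e-06},
    #         },
    #     )
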
    def compute_cost(self, llm_idx: int, input_text: str, output_size: int) -> float:
        """
        Calculate the cost of a model query based on input and output token counts.

        Args:
            llm_idx: Index of the model in the llm_names list
            input_text: The input prompt sent to the model
            output_size: Number of tokens in the model's response

        Returns:
            float: The calculated cost in currency units
        """
        # Count input tokens
        input_size = len(tokenizer(input_text)['input_ids'])

        # Get pricing information for the selected model
        llm_name = self.llm_names[llm_idx]
        input_price = self.llm_description[llm_name]["input_price"]
        output_price = self.llm_description[llm_name]["output_price"]

        # Calculate total cost
        cost = input_size * input_price + output_size * output_price
        return cost

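    # Worked example (hypothetical prices): with input_price=1e-06 and output_price=2e-06,
    # a 500-token prompt followed by a 200-token response costs
    #     500 * 1e-06 + 200 * 2e-06 = 0.0009
    # in the same currency units as the per-token prices.
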
    def get_llm_response(self, query: str, llm_idx: int) -> str:
        """
        Send a query to a language model and get its response.

        Args:
            query: The prompt text to send to the model
            llm_idx: Index of the model in the llm_names list

        Returns:
            str: The model's text response

        Note:
            Includes a retry mechanism with a 2-second delay if the first attempt fails
        """
        llm_name = self.llm_names[llm_idx]
        model = self.llm_description[llm_name]["model"]
        try:
            response = model_prompting(llm_model=model, prompt=query)
        except Exception:
            # If the request fails, wait and retry once
            time.sleep(2)
            response = model_prompting(llm_model=model, prompt=query)
        return response

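    # Usage sketch (assumes an `engine` built as in the example above and a working
    # model_prompting backend in utils):
    #     answer = engine.get_llm_response("What is 2 + 2?", llm_idx=0)
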
    def eval(self, prediction: str, ground_truth: str, metric: str) -> float:
        """
        Evaluate the model's prediction against the ground truth using the specified metric.

        Args:
            prediction: The model's output text
            ground_truth: The correct expected answer
            metric: The evaluation metric to use (e.g., 'em', 'f1_score', 'GSM8K')

        Returns:
            float: Evaluation score (typically between 0 and 1)
        """
        # Exact match evaluation
        if metric == 'em':
            result = exact_match_score(prediction, ground_truth)
            return float(result)
        # Multiple choice exact match
        elif metric == 'em_mc':
            result = exact_match_score(prediction, ground_truth, normal_method="mc")
            return float(result)
        # BERT-based semantic similarity score
        elif metric == 'bert_score':
            result = get_bert_score([prediction], [ground_truth])
            return result
        # GSM8K math problem evaluation
        # Extracts the final answer from the format "<answer>" and checks against ground truth
        elif metric == 'GSM8K':
            # Extract the final answer from ground truth (after the "####" delimiter)
            ground_truth = ground_truth.split("####")[-1].strip()
            # Look for an answer enclosed in angle brackets <X>
            match = re.search(r'\<(\d+)\>', prediction)
            if match:
                if match.group(1) == ground_truth:
                    return 1.0  # Correct answer
                else:
                    return 0.0  # Incorrect answer
            else:
                return 0.0  # No answer in expected format
        # F1 score for partial matching (used in QA tasks)
        elif metric == 'f1_score':
            f1, prec, recall = f1_score(prediction, ground_truth)
            return f1
        # Default case for unrecognized metrics
        else:
            return 0.0
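

# Minimal end-to-end sketch under stated assumptions: the model name, per-token prices,
# and the canned response below are hypothetical, and the live call through
# get_llm_response() is commented out because it needs a working API backend.
if __name__ == "__main__":
    engine = LLMEngine(
        llm_names=["model-a"],
        llm_description={
            "model-a": {"model": "provider/model-a", "input_price": 1e-06, "output_price": 2e-06},
        },
    )

    prompt = "Compute 8 * 9. Give the final answer in angle brackets, e.g. <5>."
    # response = engine.get_llm_response(prompt, llm_idx=0)  # requires a live backend
    response = "The product is <72>."  # offline stand-in response

    output_size = len(tokenizer(response)["input_ids"])
    cost = engine.compute_cost(llm_idx=0, input_text=prompt, output_size=output_size)
    score = engine.eval(response, "8 * 9 = 72 #### 72", metric="GSM8K")
    print(f"cost={cost:.6f}, GSM8K score={score}")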