# NOTE: removed stray "Spaces: Sleeping" banner text (Hugging Face UI residue,
# not Python code — it would break the module at import time).
# strategy.py
# TODO: update paths
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

from langfuse import Langfuse
from langfuse.decorators import langfuse_context, observe
from llama_cpp import Llama
# --- Langfuse observability setup (module-import side effect) ---
# FIXME(security): API credentials are hard-coded in source. Move them to real
# environment variables / a secrets manager and read them here instead.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"  # EU region
try:
    # `Langfuse` was referenced without being imported (NameError, not a
    # connection error); import locally so failure still degrades gracefully.
    from langfuse import Langfuse
    langfuse = Langfuse()
except Exception:
    # Keep the module importable when Langfuse is unreachable; leave a
    # well-defined sentinel instead of an unbound name.
    langfuse = None
    print("Langfuse Offline")
class GenerationStrategy(ABC):
    """Abstract base class for text-generation strategies.

    Concrete strategies implement `generate`, producing text from the
    model/tokenizer exposed by `generator` according to the strategy's policy.
    """

    @abstractmethod  # was missing: the base class was silently instantiable
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate a response for `prompt` using `generator`.

        Args:
            generator: Object exposing `model`, `tokenizer`, `device` (and,
                for scoring strategies, `prm_model`).
            prompt: Input text to generate from.
            model_kwargs: Keyword arguments forwarded to `model.generate`.
            **kwargs: Strategy-specific options.

        Returns:
            The generated text (some strategies return a list of strings).
        """
        raise NotImplementedError
class DefaultStrategy(GenerationStrategy):
    """Single-sample generation: one forward pass, one decoded string."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Tokenize `prompt`, run one `model.generate` call, decode the result.

        Args:
            generator: Provides `tokenizer`, `model`, and `device`.
            prompt: Input text.
            model_kwargs: Forwarded verbatim to `model.generate`.

        Returns:
            The decoded generation (special tokens stripped).
        """
        # (Removed a fully commented-out duplicate of this method that
        # previously followed it — dead code.)
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)
class MajorityVotingStrategy(GenerationStrategy):
    """Sample `num_samples` completions and return the most frequent one."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Run `num_samples` generations and majority-vote over decoded text.

        NOTE(review): distinct samples require sampling to be enabled via
        `model_kwargs` (e.g. `do_sample=True`); with greedy decoding every
        sample is identical — confirm against callers.
        """
        # Tokenizing the prompt is loop-invariant: do it once, not per sample.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        # Majority vote: the decoded string occurring most often wins.
        return max(set(outputs), key=outputs.count)
class BestOfN(GenerationStrategy):
    """Sample N candidates, score each with the PRM model, return the best."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Generate `num_samples` candidates and return the highest-scoring one.

        Args:
            generator: Provides `tokenizer`, `model`, `device`, and a
                llama_cpp-style callable `prm_model` used for scoring.
            prompt: Input text.
            model_kwargs: Forwarded verbatim to `model.generate`.
            num_samples: Number of candidates to draw.
            **kwargs: Optional `system_message` for the PRM scoring prompt
                (backward-compatible addition; previously an undefined name).

        Returns:
            The candidate response with the highest PRM score.
        """
        system_message = kwargs.get("system_message", "Rate the quality of the following answer.")
        # Tokenizing the prompt is loop-invariant: do it once.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        scored_outputs = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # Score with the PRM model. BUG FIX: the template was passed as a
            # plain string literal, so "{response}" was never interpolated —
            # it must be an f-string.
            prm_output = generator.prm_model(
                f"<|system|>\n{system_message}</s>\n<|user|>\n{response}</s>\n<|assistant|>",
                max_tokens=512,
                # NOTE(review): confirm "</s>" matches the PRM model's actual
                # chat template before relying on this stop token.
                stop=["</s>"],
                echo=True,
            )
            # NOTE(review): llama_cpp completions are plain dicts, so a
            # `logits` attribute is typically absent and the score falls back
            # to 0.0 — TODO: derive a real score (e.g. from logprobs).
            score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
            scored_outputs.append((response, score))
        # Highest-scoring candidate wins.
        return max(scored_outputs, key=lambda x: x[1])[0]
class BeamSearch(GenerationStrategy):
    """Beam-search decoding that returns all `num_samples` beams."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
        """Decode with `num_samples` beams and return every beam's text.

        BUG FIX: the return annotation said `str`, but this method has always
        returned a list of decoded strings — annotation corrected to match
        the actual (unchanged) behavior.
        """
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
class DVT(GenerationStrategy):
    """Diverse-verifier tree search: sample an initial layer of candidates,
    then repeatedly extend the best ones and rescore with the PRM model."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Return the highest-PRM-scored response found by the tree search.

        Args:
            generator: Provides `tokenizer`, `model`, `device`, `prm_model`.
            prompt: Input text.
            model_kwargs: Forwarded verbatim to `model.generate`.
            num_samples: Default branching factor.
            **kwargs: Optional `breadth` (candidates kept per layer, defaults
                to `num_samples`) and `depth` (total layers, defaults to 2).

        Returns:
            The single best-scoring response across all layers.
        """
        # BUG FIX: `breadth` and `depth` were free (undefined) names, so this
        # method always raised NameError. Read them from kwargs with
        # backward-compatible defaults.
        breadth: int = kwargs.get("breadth", num_samples)
        depth: int = kwargs.get("depth", 2)
        results: List[Tuple[str, float]] = []
        # Initial layer: `breadth` independent samples from the prompt.
        # (Prompt tokenization hoisted out of the loop — it is invariant.)
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        for _ in range(breadth):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            results.append((response, score))
        # Deeper layers: extend the current best `breadth` responses.
        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _score in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
                results.append((extended_response, score))
        return max(results, key=lambda x: x[1])[0]
class COT(GenerationStrategy):
    """Chain-of-thought strategy — placeholder, not yet implemented."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Stub: returns a fixed placeholder string until CoT is implemented."""
        # TODO(owner): implement the chain-of-thought strategy.
        return "Not implemented yet"
class ReAct(GenerationStrategy):
    """ReAct (reason + act) strategy — placeholder, not yet implemented."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Stub: returns a fixed placeholder string until ReAct is implemented."""
        # TODO(owner): implement the ReAct framework.
        return "Not implemented yet"
# Add other strategy implementations here.