# strategy.py
# TODO: update paths
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, Dict, List, Tuple
class GenerationStrategy(ABC):
    """Abstract base class for text-generation strategies.

    Concrete strategies implement :meth:`generate` to produce text from a
    prompt using the model/tokenizer held by a ``BaseGenerator``.
    """

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate a completion for ``prompt``.

        Args:
            generator: Owner object providing ``model``, ``tokenizer`` and
                ``device`` (see the concrete strategies below).
            prompt: Input text to complete.
            model_kwargs: Keyword arguments forwarded to ``model.generate``.
            **kwargs: Strategy-specific options.

        Returns:
            The generated text.
        """
        # The file imports `abstractmethod` for exactly this purpose; the
        # original `pass` body silently returned None from the base class.
        raise NotImplementedError
class DefaultStrategy(GenerationStrategy):
    """Plain single-pass generation: one prompt in, one decoded string out."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Encode the prompt, run one generation pass, and decode the first sequence."""
        encoded = generator.tokenizer(prompt, return_tensors="pt")
        prompt_ids = encoded.input_ids.to(generator.device)
        generated = generator.model.generate(prompt_ids, **model_kwargs)
        first_sequence = generated[0]
        return generator.tokenizer.decode(first_sequence, skip_special_tokens=True)
class MajorityVotingStrategy(GenerationStrategy):
    """Sample several completions and return the most frequent one."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Generate ``num_samples`` completions and return the majority answer.

        Args:
            generator: Owner providing ``model``, ``tokenizer``, ``device``.
            prompt: Input text to complete.
            model_kwargs: Forwarded to ``model.generate`` (presumably includes
                sampling settings so the samples differ — confirm at call site).
            num_samples: Number of completions to draw.

        Returns:
            The most common decoded completion; ties break by first occurrence.
        """
        # The prompt encoding is loop-invariant — the original re-tokenized
        # the identical prompt on every iteration.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        # Counter does one O(n) pass instead of O(n^2) `outputs.count` calls.
        return Counter(outputs).most_common(1)[0][0]
class BestOfN(GenerationStrategy):
    """Sample N completions and return the one the reward model scores highest."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Draw ``num_samples`` completions and keep the best-scored one.

        Bug fix: the original referenced ``self.llama_tokenizer`` /
        ``self.llama_model`` / ``self.prm_model`` / ``self.device``, none of
        which exist on a strategy object — all model state lives on the
        ``generator``, as in the other strategies in this module.

        NOTE(review): assumes the generator exposes a ``prm_model`` scoring
        head alongside ``model``/``tokenizer`` — confirm against BaseGenerator.

        Returns:
            The decoded completion with the highest mean reward-model logit.
        """
        # Prompt encoding is loop-invariant; compute it once.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        scored_outputs = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # Score = mean logit of the reward model over the response tokens.
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            scored_outputs.append((response, score))
        return max(scored_outputs, key=lambda pair: pair[1])[0]
class BeamSearch(GenerationStrategy):
    """Beam-search decoding that returns all beams."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
        """Run beam search with ``num_samples`` beams.

        Bug fixes vs. the original: it referenced ``self.llama_tokenizer`` /
        ``self.llama_model`` / ``self.device`` (attributes a strategy does not
        have — model state lives on ``generator``), and it was annotated
        ``-> str`` while actually returning a list of decoded sequences.

        Returns:
            ``num_samples`` decoded sequences, best beam first (per the
            generate API's beam ordering).
        """
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs,
        )
        return [generator.tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]
class DVT(GenerationStrategy):
    """Diverse tree search: sample an initial layer of candidates, then
    repeatedly extend the best-scored ones and return the overall best."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, breadth: int = 3, depth: int = 2, **kwargs) -> str:
        """Search a (breadth x depth) tree of completions scored by the PRM.

        Bug fixes vs. the original: ``breadth`` and ``depth`` were undefined
        names (guaranteed NameError) — they are now keyword parameters with
        defaults; and it referenced ``self.llama_tokenizer`` / ``self.llama_model``
        / ``self.prm_model`` / ``self.device``, which do not exist on a strategy —
        model state lives on the ``generator`` as in the other strategies.

        Args:
            breadth: Candidates sampled per layer.
            depth: Number of layers (1 initial + ``depth - 1`` extensions).

        Returns:
            The candidate (initial or extended) with the highest score.
        """
        def _score(text: str) -> float:
            # Mean reward-model logit over the tokens of *text*.
            # NOTE(review): assumes generator exposes `prm_model` — confirm.
            return generator.prm_model(**generator.tokenizer(text, return_tensors="pt").to(generator.device)).logits.mean().item()

        # Layer 0: sample `breadth` candidates from the prompt. The prompt
        # encoding is loop-invariant, so encode once.
        prompt_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        results = []
        for _ in range(breadth):
            output = generator.model.generate(prompt_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            results.append((response, _score(response)))

        # Layers 1..depth-1: extend the current best `breadth` candidates.
        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda pair: pair[1], reverse=True)[:breadth]
            for response, _unused in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                results.append((extended_response, _score(extended_response)))

        return max(results, key=lambda pair: pair[1])[0]
class COT(GenerationStrategy):
    """Chain-of-thought prompting strategy (placeholder)."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Stub: returns a fixed marker string until CoT is implemented."""
        # TODO implement the chain of thought strategy
        placeholder = "Not implemented yet"
        return placeholder
class ReAct(GenerationStrategy):
    """ReAct (reason + act) strategy (placeholder)."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Stub: returns a fixed marker string until ReAct is implemented."""
        # TODO implement the ReAct framework
        placeholder = "Not implemented yet"
        return placeholder
# Add other strategy implementations...