|
|
|
|
|
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple, Union

# NOTE(review): `observe` is used as a decorator below but is never imported in
# this file; presumably `from langfuse.decorators import observe` — confirm and
# add the import.
|
|
|
class GenerationStrategy(ABC):
    """Base class for generation strategies.

    A strategy turns a prompt into text by driving the tokenizer/model held by
    a ``BaseGenerator``. Subclasses must implement :meth:`generate`.
    """

    # NOTE(review): this class used to be decorated with ``@observe()``. Assuming
    # ``observe`` is a function-tracing decorator (e.g. Langfuse's), it would wrap
    # the class and return a plain function, so ``class X(GenerationStrategy)``
    # would fail. An abstract class has no behaviour to trace, so tracing belongs
    # on the concrete ``generate`` methods instead.

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate text for ``prompt``.

        Args:
            generator: The ``BaseGenerator`` (defined elsewhere) whose tokenizer
                and model the strategy drives.
            prompt: Input text.
            model_kwargs: Keyword arguments forwarded to the model's ``generate``.
            **kwargs: Strategy-specific options.

        Returns:
            The generated text.
        """
|
|
|
|
|
class DefaultStrategy(GenerationStrategy):
    """Single-pass generation: one ``model.generate`` call, first sequence decoded."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Run the model once on ``prompt`` and return its first decoded sequence.

        Args:
            generator: Supplies ``tokenizer``, ``model`` and ``device``.
            prompt: Input text.
            model_kwargs: Forwarded verbatim to ``generator.model.generate``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The decoded first output sequence, with special tokens stripped.
        """
        tokenizer = generator.tokenizer
        encoded = tokenizer(prompt, return_tensors="pt")
        prompt_ids = encoded.input_ids.to(generator.device)
        sequences = generator.model.generate(prompt_ids, **model_kwargs)
        first_sequence = sequences[0]
        return tokenizer.decode(first_sequence, skip_special_tokens=True)
|
|
|
class MajorityVotingStrategy(GenerationStrategy):
    """Sample several completions and return the most frequent one."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any],
                 num_samples: int = 5, **kwargs) -> str:
        """Return the completion that occurs most often among ``num_samples`` draws.

        Args:
            generator: Supplies ``tokenizer``, ``model`` and ``device``.
            prompt: Input text.
            model_kwargs: Forwarded to ``generator.model.generate`` on every draw.
            num_samples: Number of completions to draw (must be >= 1).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The most frequent decoded completion. Ties go to the completion that
            was generated first.

        Raises:
            ValueError: If ``num_samples`` < 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        tokenizer = generator.tokenizer
        # The prompt is identical for every draw, so encode it only once.
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)

        # NOTE(review): votes presumably only differ when model_kwargs enables
        # sampling (e.g. do_sample=True); with deterministic decoding every
        # draw is the same string.
        outputs = [
            tokenizer.decode(generator.model.generate(input_ids, **model_kwargs)[0], skip_special_tokens=True)
            for _ in range(num_samples)
        ]

        # Counter.most_common orders equal counts by first occurrence, so the
        # winner is deterministic. (max over a set depended on string-hash order.)
        return Counter(outputs).most_common(1)[0][0]
|
|
|
class BestOfN(GenerationStrategy):
    """Draw N completions, score each with a reward model, and return the best."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any],
                 num_samples: int = 5, **kwargs) -> str:
        """Return the highest-scoring of ``num_samples`` completions of ``prompt``.

        Args:
            generator: Supplies ``tokenizer``, ``model``, ``device`` and the reward
                model used by :meth:`_score`.
            prompt: Input text.
            model_kwargs: Forwarded to ``generator.model.generate`` on every draw.
            num_samples: Number of candidates to draw (must be >= 1).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The decoded candidate with the largest score. ``max`` keeps the first
            of equally scored candidates.

        Raises:
            ValueError: If ``num_samples`` < 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        tokenizer = generator.tokenizer
        # The original read self.llama_tokenizer / self.llama_model / self.device,
        # which a strategy never has; the generator owns these objects.
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)

        scored_outputs: List[Tuple[str, float]] = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = tokenizer.decode(output[0], skip_special_tokens=True)
            scored_outputs.append((response, self._score(generator, response)))

        return max(scored_outputs, key=lambda item: item[1])[0]

    @staticmethod
    def _score(generator: 'BaseGenerator', text: str) -> float:
        """Score ``text`` as the mean of the reward model's output logits."""
        # NOTE(review): the reward model is assumed to live on the generator as
        # ``prm_model`` (the original referenced self.prm_model) — confirm. It is
        # fed the generator's tokenizer, so both models are assumed to share a
        # vocabulary.
        encoded = generator.tokenizer(text, return_tensors="pt").to(generator.device)
        return generator.prm_model(**encoded).logits.mean().item()
|
|
|
class BeamSearch(GenerationStrategy):
    """Beam-search decoding with ``num_samples`` beams."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any],
                 num_samples: int = 5, *, return_all: bool = False, **kwargs) -> Union[str, List[str]]:
        """Decode ``prompt`` with beam search.

        Args:
            generator: Supplies ``tokenizer``, ``model`` and ``device``.
            prompt: Input text.
            model_kwargs: Forwarded to ``generator.model.generate``. Any
                ``num_beams`` / ``num_return_sequences`` entries are overridden
                by ``num_samples``.
            num_samples: Beam width, and the number of returned sequences
                (must be >= 1).
            return_all: If True, return every returned beam as a list. Otherwise
                return only the first beam as a ``str``, per the base-class
                contract.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If ``num_samples`` < 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        tokenizer = generator.tokenizer
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)

        # Merge rather than pass num_beams twice: duplicates in model_kwargs
        # previously raised "got multiple values for keyword argument".
        generate_kwargs = {**model_kwargs, "num_beams": num_samples, "num_return_sequences": num_samples}
        outputs = generator.model.generate(input_ids, **generate_kwargs)

        if return_all:
            return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        # NOTE(review): assumes the model returns beams best-first (as Hugging
        # Face beam search does) — confirm for the model in use.
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
class DVT(GenerationStrategy):
    """Reward-guided tree search: sample, score, and repeatedly extend the best candidates."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any],
                 num_samples: int = 5, *, breadth: Optional[int] = None, depth: int = 3,
                 **kwargs) -> str:
        """Search over continuations of ``prompt`` and return the best-scoring one.

        First draws ``breadth`` completions of ``prompt``. Then, for ``depth - 1``
        rounds, it generates one continuation from each of the ``breadth``
        highest-scoring candidates found so far and adds it to the pool.

        Args:
            generator: Supplies ``tokenizer``, ``model``, ``device`` and the reward
                model used by :meth:`_sample_and_score`.
            prompt: Input text.
            model_kwargs: Forwarded to ``generator.model.generate`` on every call.
            num_samples: Used as ``breadth`` when ``breadth`` is not given.
            breadth: Candidates drawn initially and expanded per round (>= 1).
            depth: Total rounds including the initial draw (>= 1); 1 means no
                extension.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The highest-scoring candidate across all rounds.

        Raises:
            ValueError: If ``breadth`` or ``depth`` < 1.
        """
        # The original referenced undefined names ``breadth``/``depth`` (NameError).
        # They are now parameters; NOTE(review): the default depth=3 was chosen at
        # review time — confirm.
        if breadth is None:
            breadth = num_samples
        if breadth < 1 or depth < 1:
            raise ValueError(f"breadth and depth must be >= 1, got breadth={breadth}, depth={depth}")

        results: List[Tuple[str, float]] = [
            self._sample_and_score(generator, prompt, model_kwargs) for _ in range(breadth)
        ]

        for _ in range(depth - 1):
            # The pool keeps every candidate, so an already-extended candidate can
            # be selected again in later rounds.
            best = sorted(results, key=lambda item: item[1], reverse=True)[:breadth]
            results.extend(
                self._sample_and_score(generator, response, model_kwargs) for response, _ in best
            )

        return max(results, key=lambda item: item[1])[0]

    @staticmethod
    def _sample_and_score(generator: 'BaseGenerator', text: str,
                          model_kwargs: Dict[str, Any]) -> Tuple[str, float]:
        """Generate one continuation of ``text`` and score it with the reward model."""
        tokenizer = generator.tokenizer
        input_ids = tokenizer(text, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        # NOTE(review): the reward model is assumed to live on the generator as
        # ``prm_model`` (the original referenced self.prm_model) — confirm.
        encoded = tokenizer(response, return_tensors="pt").to(generator.device)
        score = generator.prm_model(**encoded).logits.mean().item()
        return response, score
|
|
|
class COT(GenerationStrategy):
    """Placeholder strategy (presumably chain-of-thought); not implemented yet."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any],
                 num_samples: int = 5, **kwargs) -> str:
        """Placeholder: ignore all arguments and return ``"Not implemented yet"``.

        NOTE(review): callers cannot tell this string apart from real model
        output. Consider raising NotImplementedError once nothing depends on it.
        """
        return "Not implemented yet"
|
|
|
class ReAct(GenerationStrategy):
    """Placeholder strategy (presumably ReAct-style reason+act); not implemented yet."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any],
                 num_samples: int = 5, **kwargs) -> str:
        """Placeholder: ignore all arguments and return ``"Not implemented yet"``.

        NOTE(review): callers cannot tell this string apart from real model
        output. Consider raising NotImplementedError once nothing depends on it.
        """
        return "Not implemented yet"
|
|
|
|