Chris4K committed on
Commit
f59b437
·
verified ·
1 Parent(s): 52416ee

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +11 -7
services/strategy.py CHANGED
@@ -30,13 +30,17 @@ class GenerationStrategy(ABC):
30
  class DefaultStrategy(GenerationStrategy):
31
  @observe()
32
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
33
-
34
- tokenizer = generator.tokenizers["llama"]
35
- model = generator.models["llama"].generate
36
-
37
- input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
38
- output = generator.models["llama"].generate(input_ids, **model_kwargs)
39
- return generator.tokenizers["llama"].decode(output[0], skip_special_tokens=True)
 
 
 
 
40
 
41
 
42
  class MajorityVotingStrategy(GenerationStrategy):
 
30
class DefaultStrategy(GenerationStrategy):
    """Single-pass generation strategy: one tokenize → generate → decode round trip."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Generate a completion for *prompt* using the generator's model.

        Args:
            generator: owner object exposing ``tokenizer``, ``model`` and ``device``.
            prompt: raw text prompt to feed the model.
            model_kwargs: forwarded verbatim to ``model.generate`` (e.g. max_new_tokens).
            **kwargs: accepted for strategy-interface compatibility; unused here.

        Returns:
            The decoded model output with special tokens stripped.
        """
        # Move input ids to the generator's device so they are co-located with the model weights.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        # NOTE(review): output[0] is decoded as-is — for decoder-only models this
        # typically includes the prompt prefix; confirm callers expect that.
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)
44
 
45
 
46
  class MajorityVotingStrategy(GenerationStrategy):