Chris4K commited on
Commit
0477e0c
·
verified ·
1 Parent(s): 47e661d

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +3 -3
services/strategy.py CHANGED
@@ -31,7 +31,7 @@ class DefaultStrategy(GenerationStrategy):
31
  @observe()
32
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
33
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
34
- output = generator.model.generate(input_ids, **model_kwargs)
35
  return generator.tokenizer.decode(output[0], skip_special_tokens=True)
36
 
37
 
@@ -40,7 +40,7 @@ class MajorityVotingStrategy(GenerationStrategy):
40
  outputs = []
41
  for _ in range(num_samples):
42
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
43
- output = generator.model.generate(input_ids, **model_kwargs)
44
  outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
45
  return max(set(outputs), key=outputs.count)
46
 
@@ -66,7 +66,7 @@ class BeamSearch(GenerationStrategy):
66
  num_return_sequences=num_samples,
67
  **model_kwargs
68
  )
69
- return [self.llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
70
 
71
 
72
  class DVT(GenerationStrategy):
 
31
  @observe()
32
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
33
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
34
+ output = generator.generate(input_ids, **model_kwargs)
35
  return generator.tokenizer.decode(output[0], skip_special_tokens=True)
36
 
37
 
 
40
  outputs = []
41
  for _ in range(num_samples):
42
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
43
+ output = generator.generate(input_ids, **model_kwargs)
44
  outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
45
  return max(set(outputs), key=outputs.count)
46
 
 
66
  num_return_sequences=num_samples,
67
  **model_kwargs
68
  )
69
+ return [generator.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
70
 
71
 
72
  class DVT(GenerationStrategy):