Chris4K committed
Commit 26304eb · verified · 1 Parent(s): e05a0c2

Update services/strategy.py

Files changed (1):
  1. services/strategy.py +9 -9
services/strategy.py CHANGED
@@ -18,7 +18,7 @@ try:
 except Exception as e:
     print("Langfuse Offline")

-@observe()
+
 class GenerationStrategy(ABC):
     """Base class for generation strategies."""

@@ -26,15 +26,15 @@ class GenerationStrategy(ABC):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
         pass

-
+
 class DefaultStrategy(GenerationStrategy):
-
+    @observe()
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
         input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
         output = generator.model.generate(input_ids, **model_kwargs)
         return generator.tokenizer.decode(output[0], skip_special_tokens=True)

-@observe()
+
 class MajorityVotingStrategy(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         outputs = []
@@ -44,7 +44,7 @@ class MajorityVotingStrategy(GenerationStrategy):
             outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
         return max(set(outputs), key=outputs.count)

-@observe()
+
 class BestOfN(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         scored_outputs = []
@@ -56,7 +56,7 @@ class BestOfN(GenerationStrategy):
             scored_outputs.append((response, score))
         return max(scored_outputs, key=lambda x: x[1])[0]

-@observe()
+
 class BeamSearch(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         input_ids = self.llama_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
@@ -68,7 +68,7 @@ class BeamSearch(GenerationStrategy):
         )
         return [self.llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

-@observe()
+
 class DVT(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         results = []
@@ -89,14 +89,14 @@ class DVT(GenerationStrategy):
             results.append((extended_response, score))
         return max(results, key=lambda x: x[1])[0]

-@observe()
+
 class COT(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         #TODO implement the chain of thought strategy

         return "Not implemented yet"

-@observe()
+
 class ReAct(GenerationStrategy):
     def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
         #TODO implement the ReAct framework
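
The net effect of the +9/-9 change is that Langfuse's observe() is no longer applied at class level and is instead attached to DefaultStrategy.generate, which matches its role as a function/method decorator. The sketch below illustrates that pattern in isolation; it is not the repository's code. The SimpleGenerator class, the tiny model name, and the no-op fallback decorator are illustrative assumptions, and the observe import path may differ between Langfuse SDK versions.

# Minimal sketch (assumptions noted in comments) of method-level @observe() tracing.
from typing import Any, Dict

try:
    # Langfuse v2 SDK import path; newer releases may expose `observe` from `langfuse` directly.
    from langfuse.decorators import observe
except Exception:
    print("Langfuse Offline")

    def observe(*args: Any, **kwargs: Any):
        # Assumed no-op fallback so decorated methods still run without tracing.
        def decorator(fn):
            return fn
        return decorator


class SimpleGenerator:
    """Illustrative stand-in for BaseGenerator: exposes tokenizer, model, and device."""

    def __init__(self, model_name: str = "sshleifer/tiny-gpt2") -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.device = "cpu"

    @observe()  # method-level tracing, as in the updated DefaultStrategy.generate
    def generate(self, prompt: str, model_kwargs: Dict[str, Any]) -> str:
        # Same flow as DefaultStrategy.generate: tokenize, generate, decode.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        output = self.model.generate(input_ids, **model_kwargs)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)


if __name__ == "__main__":
    gen = SimpleGenerator()
    print(gen.generate("Hello", {"max_new_tokens": 10}))

The no-op fallback decorator is an assumption of this sketch; the visible except block in services/strategy.py only prints "Langfuse Offline", so the actual module may handle a missing observe differently.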