Chris4K commited on
Commit
52416ee
·
verified ·
1 Parent(s): ae307db

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +6 -6
services/strategy.py CHANGED
@@ -40,7 +40,7 @@ class DefaultStrategy(GenerationStrategy):
40
 
41
 
42
  class MajorityVotingStrategy(GenerationStrategy):
43
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
44
  outputs = []
45
  for _ in range(num_samples):
46
  input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
@@ -50,7 +50,7 @@ class MajorityVotingStrategy(GenerationStrategy):
50
 
51
 
52
  class BestOfN(GenerationStrategy):
53
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
54
  scored_outputs = []
55
  for _ in range(num_samples):
56
  input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
@@ -62,7 +62,7 @@ class BestOfN(GenerationStrategy):
62
 
63
 
64
  class BeamSearch(GenerationStrategy):
65
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
66
  input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
67
  outputs = generator.models["llama"].generate(
68
  input_ids,
@@ -74,7 +74,7 @@ class BeamSearch(GenerationStrategy):
74
 
75
 
76
  class DVT(GenerationStrategy):
77
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
78
  results = []
79
  for _ in range(breadth):
80
  input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
@@ -95,14 +95,14 @@ class DVT(GenerationStrategy):
95
 
96
 
97
  class COT(GenerationStrategy):
98
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
99
  #TODO implement the chain of thought strategy
100
 
101
  return "Not implemented yet"
102
 
103
 
104
  class ReAct(GenerationStrategy):
105
- def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs, num_samples: int = 5) -> str:
106
  #TODO implement the ReAct framework
107
  return "Not implemented yet"
108
  # Add other strategy implementations...
 
40
 
41
 
42
  class MajorityVotingStrategy(GenerationStrategy):
43
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
44
  outputs = []
45
  for _ in range(num_samples):
46
  input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
 
50
 
51
 
52
  class BestOfN(GenerationStrategy):
53
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
54
  scored_outputs = []
55
  for _ in range(num_samples):
56
  input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
 
62
 
63
 
64
  class BeamSearch(GenerationStrategy):
65
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
66
  input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
67
  outputs = generator.models["llama"].generate(
68
  input_ids,
 
74
 
75
 
76
  class DVT(GenerationStrategy):
77
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
78
  results = []
79
  for _ in range(breadth):
80
  input_ids = generator.tokenizers["llama"](prompt, return_tensors="pt").input_ids.to(generator.device)
 
95
 
96
 
97
  class COT(GenerationStrategy):
98
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
99
  #TODO implement the chain of thought strategy
100
 
101
  return "Not implemented yet"
102
 
103
 
104
  class ReAct(GenerationStrategy):
105
+ def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
106
  #TODO implement the ReAct framework
107
  return "Not implemented yet"
108
  # Add other strategy implementations...