Vipitis committed on
Commit 428da0b · 1 Parent(s): c11be31

massively improving speed

Files changed (1)
  1. ShaderEval.py +13 -1
ShaderEval.py CHANGED
@@ -94,9 +94,21 @@ class ReturnGenerationEvaluator(evaluate.TextGenerationEvaluator):
         raise ValueError(
             f"Incompatible `model_or_pipeline`. Please specify `model_or_pipeline` compatible with the `{self.task}` task."
         )
+
+        # fixing the default for max_length
+        pipe.model.config.max_length = self._resolve_context_lenght(pipe=pipe)
+
+        # specify eos tokens to be all of those that include a ";" so we can stop early.
+        self.PIPELINE_KWARGS.update({"eos_token_id": [v for k, v in pipe.tokenizer.vocab.items() if ";" in k]})  # didn't see that this was passed all the way already.
+        # solution found here: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.eos_token_id but does it actually work?
+
         return pipe
 
-    def _resolve_context_lenght(self, model_or_pipeline=None):  # TODO: should really copy the typing hints here.
+    def _resolve_context_lenght(self, model_or_pipeline=None, pipe=None):  # TODO: should really copy the typing hints here.
+        if isinstance(model_or_pipeline, transformers.GPT2Model):  # you are comparing a string here -.-
+            return model_or_pipeline.config.n_ctx  # how GPT2 models expose the context length
+        if pipe is not None:  # should I figure out a way to pass this?
+            return pipe.tokenizer.model_max_length  # set to something small by the pipeline task default, but we want the model maximum instead.
         # the tokenizer needs to know the context length for our pipe strategy, but it has to be passed to the tokenizer, not the model.
         # the tokenizer should read it from the model config, but that can be wrong, or there is a task override (for "text-generation", for example, you get 50)
         # model_or_pipeline only exists via the .compute call, so we have to take it in
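
For readers wondering where the speed-up comes from: the evaluated completions are shader return statements, which end in ";", so declaring every vocab entry that contains a semicolon as an end-of-sequence token lets generation stop after the first statement instead of sampling all the way to max_length. Below is a minimal standalone sketch of the same trick, assuming a transformers version with list-valued eos_token_id support (the feature the commit links to); the gpt2 checkpoint and the prompt are illustrative placeholders, not what ShaderEval ships with.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "gpt2"  # illustrative; any causal LM checkpoint works
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)

    # every vocab entry whose surface form contains ";" becomes a stop token
    stop_ids = [v for k, v in tokenizer.vocab.items() if ";" in k]

    inputs = tokenizer("    return color", return_tensors="pt")
    output = model.generate(
        **inputs,
        max_length=50,
        eos_token_id=stop_ids,               # a list of ids is accepted here
        pad_token_id=tokenizer.eos_token_id, # silence the missing-pad warning
    )
    print(tokenizer.decode(output[0]))

In the commit itself the id list travels through self.PIPELINE_KWARGS, which the evaluator hands to the "text-generation" pipeline on every call (the author's own "didn't see that this was passed all the way already" note), so no per-example plumbing is needed.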
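As for _resolve_context_lenght: the context window hides under different config attributes depending on the model family (n_ctx on GPT-2-style configs, max_position_embeddings on most newer ones), and tokenizer.model_max_length may reflect a small task default rather than the model's true limit, which is what the "you get 50" comment is about. A hedged sketch of one possible fallback chain; the function name resolve_context_length, its default value, and the lookup order are my illustration, not the evaluator's exact logic.

    from transformers import AutoConfig, AutoTokenizer

    def resolve_context_length(checkpoint: str, default: int = 1024) -> int:
        """Best-effort lookup; the attribute name varies between model families."""
        config = AutoConfig.from_pretrained(checkpoint)
        for attr in ("n_ctx", "max_position_embeddings"):  # GPT-2 style first
            value = getattr(config, attr, None)
            if value is not None:
                return value
        # fall back to the tokenizer; an unset value shows up as a huge sentinel
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        if tokenizer.model_max_length < 1_000_000:
            return tokenizer.model_max_length
        return default

    print(resolve_context_length("gpt2"))  # -> 1024, via config.n_ctx

Checking the config before the tokenizer sidesteps the task override the comments complain about; the 1_000_000 guard filters out the very large sentinel transformers stores when a tokenizer has no recorded maximum.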