waleko committed
Commit 7ab5a30 · 1 parent: 53f2284

try fix gen

Files changed (1):
  infer.py  +8 -2
infer.py CHANGED
@@ -203,13 +203,14 @@ class TikzGenerator:
             top_p=top_p,
             top_k=top_k,
             num_return_sequences=1,
-            max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
+            # max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
             do_sample=True,
             return_full_text=False,
             streamer=TextStreamer(self.pipeline.tokenizer, # type: ignore
                                   skip_prompt=True,
                                   skip_special_tokens=True
                                   ),
+            max_new_tokens=1024,
         )
 
         if not stream:
@@ -218,8 +219,11 @@ class TikzGenerator:
     def generate(self, image: Image.Image, **generate_kwargs):
         prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
         tokenizer = self.pipeline.tokenizer
+        print('starting generation')
         text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore
 
+        print('text generated: ', text) # TODO: remove
+
         if self.clean_up_output:
             for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
                 # remove leading characters because skip_special_tokens in pipeline
@@ -236,7 +240,9 @@ class TikzGenerator:
         for artifact, replacement in artifacts.items():
             text = sub(artifact, replacement, text) # type: ignore
 
-        return text
+        print('cleaned text: ', text)
+
+        return TikzDocument(text)
 
 
     def __call__(self, *args, **kwargs):
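
For context, a minimal usage sketch of the generate path as it looks after this commit. Only generate(image, **generate_kwargs), the merged default_kwargs (do_sample, top_p, top_k, max_new_tokens=1024, ...), and the new TikzDocument return value come from infer.py; the constructor arguments, import paths, and the checkpoint name below are assumptions for illustration.

    from PIL import Image
    from transformers import pipeline

    from infer import TikzGenerator  # import path assumed for this repo

    # Hypothetical setup: infer.py presumably wraps an image-to-text pipeline;
    # the checkpoint name is a placeholder, and the constructor signature is assumed.
    pipe = pipeline("image-to-text", model="llava-hf/llava-1.5-7b-hf")
    generator = TikzGenerator(pipe)

    image = Image.open("diagram.png")

    # generate() merges these kwargs into default_kwargs before calling the pipeline,
    # so per-call values such as top_p/top_k override the defaults.
    result = generator.generate(image, top_p=0.9, top_k=50)

    # After this commit the cleaned-up output is wrapped in a TikzDocument
    # instead of being returned as a raw string.
    print(result)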