Spaces:
Runtime error
Runtime error
try fix gen
Browse files
infer.py
CHANGED
@@ -203,13 +203,14 @@ class TikzGenerator:
|
|
203 |
top_p=top_p,
|
204 |
top_k=top_k,
|
205 |
num_return_sequences=1,
|
206 |
-
max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
|
207 |
do_sample=True,
|
208 |
return_full_text=False,
|
209 |
streamer=TextStreamer(self.pipeline.tokenizer, # type: ignore
|
210 |
skip_prompt=True,
|
211 |
skip_special_tokens=True
|
212 |
),
|
|
|
213 |
)
|
214 |
|
215 |
if not stream:
|
@@ -218,8 +219,11 @@ class TikzGenerator:
|
|
218 |
def generate(self, image: Image.Image, **generate_kwargs):
|
219 |
prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
|
220 |
tokenizer = self.pipeline.tokenizer
|
|
|
221 |
text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore
|
222 |
|
|
|
|
|
223 |
if self.clean_up_output:
|
224 |
for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
|
225 |
# remove leading characters because skip_special_tokens in pipeline
|
@@ -236,7 +240,9 @@ class TikzGenerator:
|
|
236 |
for artifact, replacement in artifacts.items():
|
237 |
text = sub(artifact, replacement, text) # type: ignore
|
238 |
|
239 |
-
|
|
|
|
|
240 |
|
241 |
|
242 |
def __call__(self, *args, **kwargs):
|
|
|
203 |
top_p=top_p,
|
204 |
top_k=top_k,
|
205 |
num_return_sequences=1,
|
206 |
+
# max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
|
207 |
do_sample=True,
|
208 |
return_full_text=False,
|
209 |
streamer=TextStreamer(self.pipeline.tokenizer, # type: ignore
|
210 |
skip_prompt=True,
|
211 |
skip_special_tokens=True
|
212 |
),
|
213 |
+
max_new_tokens=1024,
|
214 |
)
|
215 |
|
216 |
if not stream:
|
|
|
219 |
def generate(self, image: Image.Image, **generate_kwargs):
|
220 |
prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
|
221 |
tokenizer = self.pipeline.tokenizer
|
222 |
+
print('starting generation')
|
223 |
text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore
|
224 |
|
225 |
+
print('text generated: ', text) # TODO: remove
|
226 |
+
|
227 |
if self.clean_up_output:
|
228 |
for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
|
229 |
# remove leading characters because skip_special_tokens in pipeline
|
|
|
240 |
for artifact, replacement in artifacts.items():
|
241 |
text = sub(artifact, replacement, text) # type: ignore
|
242 |
|
243 |
+
print('cleaned text: ', text)
|
244 |
+
|
245 |
+
return TikzDocument(text)
|
246 |
|
247 |
|
248 |
def __call__(self, *args, **kwargs):
|