Update chatbot.py
chatbot.py CHANGED (+13 -12)
@@ -315,7 +315,7 @@ def model_inference(
     temperature,
     max_new_tokens,
     repetition_penalty,
-    top_p,
+    min_p,
     web_search,
 ):
     # Define generation_args at the beginning of the function
@@ -332,6 +332,7 @@ def model_inference(
     generate_kwargs = dict(
         max_new_tokens=4000,
         do_sample=True,
+        min_p=0.08,
     )
     # Format the prompt for the language model
     formatted_prompt = format_prompt(
@@ -351,6 +352,7 @@ def model_inference(
     generate_kwargs = dict(
         max_new_tokens=5000,
         do_sample=True,
+        min_p=0.08,
     )
     # Format the prompt for the language model
     formatted_prompt = format_prompt(
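Both branches now hard-code `min_p=0.08` next to `do_sample=True`. Whether the backend honors it depends on where `generate_kwargs` is sent: recent `transformers` releases accept `min_p` as a generation parameter, while `huggingface_hub`'s `InferenceClient.text_generation` does not document a `min_p` argument, so forwarding these kwargs unchanged to that client is worth double-checking. A minimal sketch against a local `transformers` model (the model name is illustrative, not the Space's):

```python
# Sketch assuming a local transformers backend; recent transformers releases
# accept `min_p` in generate()/GenerationConfig. "gpt2" is illustrative only.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    min_p=0.08,  # same value the diff hard-codes in generate_kwargs
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```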
@@ -389,16 +391,15 @@ def model_inference(
     }
     assert decoding_strategy in [
         "Greedy",
-        "Top P Sampling",
+        "Min P Sampling",
     ]

     if decoding_strategy == "Greedy":
         generation_args["do_sample"] = False
-    elif decoding_strategy == "Top P Sampling":
+    elif decoding_strategy == "Min P Sampling":
         generation_args["temperature"] = temperature
         generation_args["do_sample"] = True
-        generation_args["top_p"] = top_p
-    # Creating model inputs
+        generation_args["min_p"] = min_p
     (
         resulting_text,
         resulting_images,
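For reference, the min-p rule keeps a token only if its probability is at least `min_p` times the probability of the most likely token, then renormalizes and samples. Unlike top-p's fixed cumulative-probability budget, the surviving set shrinks when the model is confident and widens when it is uncertain. An illustrative NumPy sketch of the rule (not code from this Space):

```python
import numpy as np

def min_p_sample(probs: np.ndarray, min_p: float = 0.08) -> int:
    """Sample a token index, keeping only tokens with p >= min_p * max(p)."""
    keep = probs >= min_p * probs.max()
    filtered = np.where(keep, probs, 0.0)
    filtered /= filtered.sum()  # renormalize over the surviving tokens
    return int(np.random.default_rng().choice(len(probs), p=filtered))

probs = np.array([0.50, 0.30, 0.15, 0.04, 0.01])
# Threshold = 0.08 * 0.50 = 0.04, so only the 0.01 token is filtered out here.
print(min_p_sample(probs))
```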
@@ -440,7 +441,7 @@ FEATURES = datasets.Features(
         "temperature": datasets.Value("float32"),
         "max_new_tokens": datasets.Value("int32"),
         "repetition_penalty": datasets.Value("float32"),
-        "top_p": datasets.Value("int32"),
+        "min_p": datasets.Value("int32"),
     }
 )

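One detail worth flagging: the new `min_p` feature is declared as `int32`, but the slider added below produces fractional values between 0.01 and 0.49, which an integer column would truncate to 0 when examples are logged. A sketch of the same block with `float32`, matching `temperature` and `repetition_penalty`:

```python
# Sketch: storing min_p as float32 instead of int32 preserves fractional
# slider values such as 0.08 when examples are written to the dataset.
import datasets

FEATURES = datasets.Features(
    {
        "temperature": datasets.Value("float32"),
        "max_new_tokens": datasets.Value("int32"),
        "repetition_penalty": datasets.Value("float32"),
        "min_p": datasets.Value("float32"),
    }
)
```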
@@ -465,9 +466,9 @@ repetition_penalty = gr.Slider(
 decoding_strategy = gr.Radio(
     [
         "Greedy",
-        "Top P Sampling",
+        "Min P Sampling",
     ],
-    value="Top P Sampling",
+    value="Min P Sampling",
     label="Decoding strategy",
     interactive=True,
     info="Higher values are equivalent to sampling more low-probability tokens.",
@@ -482,14 +483,14 @@ temperature = gr.Slider(
     label="Sampling temperature",
     info="Higher values will produce more diverse outputs.",
 )
-top_p = gr.Slider(
+min_p = gr.Slider(
     minimum=0.01,
-    maximum=0.99,
-    value=0.8,
+    maximum=0.49,
+    value=0.08,
     step=0.01,
     visible=True,
     interactive=True,
-    label="Top P",
+    label="Min P",
     info="Higher values are equivalent to sampling more low-probability tokens.",
 )

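Finally, since `model_inference` receives `min_p` positionally, the new slider also has to appear in the matching position of the Gradio event's `inputs` list (that binding sits outside the hunks shown here). A self-contained toy sketch of the pattern, with illustrative component names rather than the Space's actual layout:

```python
import gradio as gr

def toy_inference(prompt, temperature, max_new_tokens, repetition_penalty, min_p, web_search):
    # Placeholder body: echo the settings to confirm the wiring matches the signature.
    return f"min_p={min_p}, temperature={temperature}, web_search={web_search}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    temperature = gr.Slider(0.0, 2.0, value=0.5, label="Sampling temperature")
    max_new_tokens = gr.Slider(1, 4000, value=512, step=1, label="Max new tokens")
    repetition_penalty = gr.Slider(0.01, 5.0, value=1.0, step=0.01, label="Repetition penalty")
    min_p = gr.Slider(0.01, 0.49, value=0.08, step=0.01, label="Min P")
    web_search = gr.Checkbox(label="Web search")
    output = gr.Textbox(label="Output")
    # `inputs` must mirror the parameter order of toy_inference.
    prompt.submit(
        toy_inference,
        inputs=[prompt, temperature, max_new_tokens, repetition_penalty, min_p, web_search],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()
```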