update reqs + small refactor
Files changed:
- app_dialogue.py  +119 -118
- requirements.txt  +1 -3
app_dialogue.py
CHANGED
@@ -546,6 +546,125 @@ def expand_layout():
     return gr.Column(scale=2), gr.Gallery(height=682)


+def generate_meme(
+    client,
+    query,
+    image,
+    font_meme_text,
+    all_caps_meme_text,
+    text_at_the_top,
+    generation_args,
+):
+    try:
+        text = client.generate(prompt=query, **generation_args).generated_text
+    except Exception as e:
+        logger.error(f"Error {e} while generating meme text")
+        text = ""
+    if image is not None and text != "":
+        meme_image = make_meme_image(
+            image=image,
+            text=text,
+            font_meme_text=font_meme_text,
+            all_caps_meme_text=all_caps_meme_text,
+            text_at_the_top=text_at_the_top,
+        )
+        return meme_image
+    else:
+        return None
+
+
+def model_inference(
+    model_selector,
+    system_prompt,
+    user_prompt_str,
+    chat_history,
+    image,
+    decoding_strategy,
+    temperature,
+    max_new_tokens,
+    repetition_penalty,
+    top_p,
+    all_caps_meme_text,
+    text_at_the_top,
+    font_meme_text,
+):
+    chat_history = []
+    if user_prompt_str.strip() == "" and image is None:
+        return "", None, chat_history
+
+    system_prompt = ast.literal_eval(system_prompt)
+    (
+        formated_prompt_list,
+        user_prompt_list,
+    ) = format_user_prompt_with_im_history_and_system_conditioning(
+        system_prompt=system_prompt,
+        current_user_prompt_str=user_prompt_str.strip(),
+        current_image=image,
+        history=chat_history,
+    )
+
+    client_endpoint = API_PATHS[model_selector]
+    client = Client(
+        base_url=client_endpoint,
+        headers={"x-use-cache": "0", "Authorization": f"Bearer {API_TOKEN}"},
+        timeout=45,
+    )
+
+    # Common parameters to all decoding strategies
+    # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
+    generation_args = {
+        "max_new_tokens": max_new_tokens,
+        "repetition_penalty": repetition_penalty,
+        "stop_sequences": EOS_STRINGS,
+    }
+
+    assert decoding_strategy in [
+        "Greedy",
+        "Top P Sampling",
+    ]
+    if decoding_strategy == "Greedy":
+        generation_args["do_sample"] = False
+    elif decoding_strategy == "Top P Sampling":
+        generation_args["temperature"] = temperature
+        generation_args["do_sample"] = True
+        generation_args["top_p"] = top_p
+
+    chat_history.append([prompt_list_to_markdown(user_prompt_list), ""])
+
+    query = prompt_list_to_tgi_input(formated_prompt_list)
+    all_meme_images = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+        futures = [
+            executor.submit(
+                generate_meme,
+                client,
+                query,
+                image,
+                font_meme_text,
+                all_caps_meme_text,
+                text_at_the_top,
+                generation_args,
+            )
+            for i in range(4)
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            meme_image = future.result(timeout=45)
+            if meme_image:
+                all_meme_images.append(meme_image)
+    return user_prompt_str, all_meme_images, chat_history
+
+
+def remove_last_turn(chat_history):
+    if len(chat_history) == 0:
+        return chat_history, "", ""
+    last_interaction = chat_history[-1]
+    chat_history = chat_history[:-1]
+    chat_update = chat_history
+    text_update = last_interaction[0]
+    return chat_update, text_update, ""
+
+
 textbox = gr.Textbox(
     placeholder="Upload an image and ask the AI to create a meme!",
     show_label=False,
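As context for the hunk above: a minimal sketch (not part of the commit) of how the generation_args built in model_inference drive the text_generation client. The endpoint URL, token, and parameter values below are placeholders, not values from this Space.

# Sketch: "Greedy" maps to do_sample=False; "Top P Sampling" adds
# temperature and top_p with do_sample=True, as in the branch above.
from text_generation import Client

client = Client(
    base_url="https://example-tgi-endpoint",  # placeholder endpoint
    headers={"x-use-cache": "0", "Authorization": "Bearer <token>"},  # placeholder token
    timeout=45,
)
generation_args = {
    "max_new_tokens": 256,      # illustrative value
    "repetition_penalty": 1.2,  # illustrative value
    "do_sample": True,          # "Top P Sampling" branch
    "temperature": 0.7,
    "top_p": 0.9,
}
response = client.generate(prompt="<formatted prompt>", **generation_args)
print(response.generated_text)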
@@ -764,115 +883,6 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
     with gr.Row():
         chatbot.render()

-    def generate_meme(
-        i,
-        client,
-        query,
-        image,
-        font_meme_text,
-        all_caps_meme_text,
-        text_at_the_top,
-        generation_args,
-    ):
-        try:
-            text = client.generate(prompt=query, **generation_args).generated_text
-        except Exception as e:
-            logger.error(f"Error {e} while generating meme text")
-            text = ""
-        if image is not None and text != "":
-            meme_image = make_meme_image(
-                image=image,
-                text=text,
-                font_meme_text=font_meme_text,
-                all_caps_meme_text=all_caps_meme_text,
-                text_at_the_top=text_at_the_top,
-            )
-            return meme_image
-        else:
-            return None
-
-    def model_inference(
-        model_selector,
-        system_prompt,
-        user_prompt_str,
-        chat_history,
-        image,
-        decoding_strategy,
-        temperature,
-        max_new_tokens,
-        repetition_penalty,
-        top_p,
-        all_caps_meme_text,
-        text_at_the_top,
-        font_meme_text,
-    ):
-        chat_history = []
-        if user_prompt_str.strip() == "" and image is None:
-            return "", None, chat_history
-
-        system_prompt = ast.literal_eval(system_prompt)
-        (
-            formated_prompt_list,
-            user_prompt_list,
-        ) = format_user_prompt_with_im_history_and_system_conditioning(
-            system_prompt=system_prompt,
-            current_user_prompt_str=user_prompt_str.strip(),
-            current_image=image,
-            history=chat_history,
-        )
-
-        client_endpoint = API_PATHS[model_selector]
-        client = Client(
-            base_url=client_endpoint,
-            headers={"x-use-cache": "0", "Authorization": f"Bearer {API_TOKEN}"},
-            timeout=45,
-        )
-
-        # Common parameters to all decoding strategies
-        # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
-        generation_args = {
-            "max_new_tokens": max_new_tokens,
-            "repetition_penalty": repetition_penalty,
-            "stop_sequences": EOS_STRINGS,
-        }
-
-        assert decoding_strategy in [
-            "Greedy",
-            "Top P Sampling",
-        ]
-        if decoding_strategy == "Greedy":
-            generation_args["do_sample"] = False
-        elif decoding_strategy == "Top P Sampling":
-            generation_args["temperature"] = temperature
-            generation_args["do_sample"] = True
-            generation_args["top_p"] = top_p
-
-        chat_history.append([prompt_list_to_markdown(user_prompt_list), ""])
-
-        query = prompt_list_to_tgi_input(formated_prompt_list)
-        all_meme_images = []
-        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-            futures = [
-                executor.submit(
-                    generate_meme,
-                    i,
-                    client,
-                    query,
-                    image,
-                    font_meme_text,
-                    all_caps_meme_text,
-                    text_at_the_top,
-                    generation_args,
-                )
-                for i in range(4)
-            ]
-
-            for future in concurrent.futures.as_completed(futures):
-                meme_image = future.result(timeout=45)
-                if meme_image:
-                    all_meme_images.append(meme_image)
-        return user_prompt_str, all_meme_images, chat_history
-
     gr.on(
         triggers=[
             textbox.submit,
@@ -906,15 +916,6 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
         outputs=[textbox, generated_memes_gallery, chatbot],
     )

-    def remove_last_turn(chat_history):
-        if len(chat_history) == 0:
-            return chat_history, "", ""
-        last_interaction = chat_history[-1]
-        chat_history = chat_history[:-1]
-        chat_update = chat_history
-        text_update = last_interaction[0]
-        return chat_update, text_update, ""
-
     regenerate_btn.click(
         fn=remove_last_turn,
         inputs=chatbot,
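The hunks above move generate_meme, model_inference, and remove_last_turn out of the gr.Blocks context to module level (dropping the unused i argument along the way); the fan-out logic itself is unchanged. As a stripped-down sketch of that pattern (illustrative only, with a dummy function standing in for generate_meme):

import concurrent.futures

def fake_generate(attempt):
    # Stand-in for generate_meme: may return None, like a failed generation.
    return f"meme-{attempt}" if attempt % 2 == 0 else None

results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # Submit the same job four times to get four candidate memes.
    futures = [executor.submit(fake_generate, i) for i in range(4)]
    for future in concurrent.futures.as_completed(futures):
        result = future.result(timeout=45)
        if result:  # keep only successful generations
            results.append(result)
print(results)  # completion order, not submission order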
requirements.txt
CHANGED
@@ -9,10 +9,8 @@ opencv-python
 numpy
 accelerate
 joblib
-deepspeed
 parameterized
 einops
 pynvml
 sentencepiece
-text_generation
-https://gradio-builds.s3.amazonaws.com/2060bfe3e7eb57fb9b5c8695ebfc900469263d1f/gradio-3.46.0-py3-none-any.whl
+text_generation
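Note on this change: deepspeed and the pinned gradio wheel are dropped (the diff shows no replacement pin here), while text_generation stays because it provides the Client used by model_inference in app_dialogue.py. A quick smoke check (illustrative, not part of the commit):

# Verify the kept dependency still resolves after trimming requirements.txt.
from text_generation import Client
print(Client)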