add multiprocessing
app_dialogue.py  CHANGED  (+43 -14)
@@ -10,6 +10,7 @@ from typing import List, Optional, Tuple
 from urllib.parse import urlparse
 from PIL import Image, ImageDraw, ImageFont
 
+import concurrent.futures
 import random
 import gradio as gr
 import PIL
@@ -777,6 +778,28 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
     with gr.Row():
         chatbot.render()
 
+    def generate_meme(
+        i,
+        client,
+        query,
+        image,
+        font_meme_text,
+        all_caps_meme_text,
+        text_at_the_top,
+        generation_args,
+    ):
+        text = client.generate(prompt=query, **generation_args).generated_text
+        if image is not None and text != "":
+            meme_image = make_meme_image(
+                image=image,
+                text=text,
+                font_meme_text=font_meme_text,
+                all_caps_meme_text=all_caps_meme_text,
+                text_at_the_top=text_at_the_top,
+            )
+            meme_image = pil_to_temp_file(meme_image)
+            return meme_image
+
     def model_inference(
         model_selector,
         system_prompt,
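Note that generate_meme only returns a value when there is an image and non-empty generated text; otherwise it falls off the end and implicitly returns None, which the collecting loop in the next hunk filters out with `if meme_image:`. (The `i` argument is unused in the body; it only carries the worker index that each executor.submit call passes in.) A tiny self-contained sketch of that implicit-None contract:

# A Python function that falls off the end without hitting a return
# statement returns None; the commit relies on this to signal
# "no meme produced" for a worker.
def maybe_meme(text):
    if text != "":
        return text.upper()

results = [maybe_meme(t) for t in ["hello", ""]]
memes = [m for m in results if m]  # mirrors the `if meme_image:` filter
print(memes)  # ['HELLO']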
@@ -849,21 +872,27 @@ with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base(), css=css) as de
 
         query = prompt_list_to_tgi_input(formated_prompt_list)
         all_meme_images = []
-        …(9 removed lines, not rendered in this view)
+        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+            futures = [
+                executor.submit(
+                    generate_meme,
+                    i,
+                    client,
+                    query,
+                    image,
+                    font_meme_text,
+                    all_caps_meme_text,
+                    text_at_the_top,
+                    generation_args,
                 )
-        …(5 removed lines, not rendered in this view)
+                for i in range(4)
+            ]
+
+            for future in concurrent.futures.as_completed(futures):
+                meme_image = future.result()
+                if meme_image:
+                    all_meme_images.append(meme_image)
+        return user_prompt_str, all_meme_images, chat_history
 
     gr.on(
         triggers=[textbox.submit, imagebox.upload, submit_btn.click],
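The net effect of the change: the four meme generations, each dominated by a blocking network call to the TGI endpoint, now run concurrently on up to four threads instead of back to back, and results are collected in completion order. A self-contained sketch of the same pattern, with a hypothetical slow_job standing in for generate_meme:

import concurrent.futures
import random
import time

def slow_job(i):
    # Stands in for generate_meme: mostly blocking I/O wait, during
    # which the GIL is released, so threads overlap nicely.
    time.sleep(random.uniform(0.1, 0.5))
    return f"meme {i}"

results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(slow_job, i) for i in range(4)]
    # as_completed yields futures in finish order, not submission order;
    # future.result() would also re-raise any exception from the worker.
    for future in concurrent.futures.as_completed(futures):
        results.append(future.result())

print(results)  # e.g. ['meme 2', 'meme 0', 'meme 3', 'meme 1']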