Niki Zhang committed: Update app.py
app.py CHANGED
@@ -1,4 +1,5 @@
 from io import BytesIO
+import io
 from math import inf
 import os
 import base64
@@ -26,7 +27,15 @@ import easyocr
 import re
 import edge_tts
 from langchain import __version__
-
+import torch
+import gradio as gr
+from transformers import AutoProcessor, SiglipModel
+import faiss
+from huggingface_hub import hf_hub_download
+from datasets import load_dataset
+import pandas as pd
+import requests
+import spaces
 # Print the current version of LangChain
 print(f"Current LangChain version: {__version__}")
 # import tts
@@ -37,7 +46,9 @@ print(f"Current LangChain version: {__version__}")
 
 
 # import spaces #
+import threading
 
+lock = threading.Lock()
 import os
 # import uuid
 # from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
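
The hunk above adds a module-level lock but no use site is visible in this diff. The usual pattern such a lock supports is serializing access to shared, non-thread-safe state (a single model instance, a cache) across concurrent Gradio requests; a minimal sketch with hypothetical names (`shared_state` and `handler` are illustrative, not from app.py):

    import threading

    lock = threading.Lock()
    shared_state = {}  # hypothetical shared resource touched by several handlers

    def handler(key, value):
        # only one request at a time may mutate the shared resource
        with lock:
            shared_state[key] = value
            return shared_state[key]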
@@ -304,6 +315,56 @@ def make3d(images):
 ###############################################################################
 
 
+###############################################################################
+############# This part is for sCLIP #############
+###############################################################################
+
+# download model and dataset
+hf_hub_download("merve/siglip-faiss-wikiart", "siglip_10k_latest.index", local_dir="./")
+hf_hub_download("merve/siglip-faiss-wikiart", "wikiart_10k_latest.csv", local_dir="./")
+
+# read index, dataset and load siglip model and processor
+index = faiss.read_index("./siglip_10k_latest.index")
+df = pd.read_csv("./wikiart_10k_latest.csv")
+device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
+processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+slipmodel = SiglipModel.from_pretrained("google/siglip-base-patch16-224").to(device)
+
+
+def read_image_from_url(url):
+    response = requests.get(url)
+    img = Image.open(BytesIO(response.content)).convert("RGB")
+    return img
+
+#@spaces.GPU
+def extract_features_siglip(image):
+    with torch.no_grad():
+        inputs = processor(images=image, return_tensors="pt").to(device)
+        image_features = slipmodel.get_image_features(**inputs)
+    return image_features
+
+@spaces.GPU
+def infer(image_path):
+    input_image = Image.open(image_path).convert("RGB")
+    input_features = extract_features_siglip(input_image.convert("RGB"))
+    input_features = input_features.detach().cpu().numpy()
+    input_features = np.float32(input_features)
+    faiss.normalize_L2(input_features)
+    distances, indices = index.search(input_features, 3)
+    gallery_output = []
+    for i,v in enumerate(indices[0]):
+        sim = -distances[0][i]
+        image_url = df.iloc[v]["Link"]
+        img_retrieved = read_image_from_url(image_url)
+        gallery_output.append(img_retrieved)
+    return gallery_output
+
+
+###############################################################################
+############# Above part is for sCLIP #############
+###############################################################################
+
+
 ###############################################################################
 ############# this part is for text to image #############
 ###############################################################################
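
In `infer` above, the query features are L2-normalized with `faiss.normalize_L2` before `index.search`, so inner-product scores coincide with cosine similarity (note that `sim = -distances[0][i]` is computed but never used). A self-contained sketch of that retrieval pattern, assuming a flat inner-product index and SigLIP-base's 768-dimensional image embeddings; the actual type of the prebuilt siglip_10k_latest.index is not visible in this diff:

    import faiss
    import numpy as np

    dim = 768  # assumed embedding width for google/siglip-base-patch16-224
    database = np.float32(np.random.randn(1000, dim))
    faiss.normalize_L2(database)           # unit norm: inner product == cosine
    index = faiss.IndexFlatIP(dim)
    index.add(database)

    query = np.float32(np.random.randn(1, dim))
    faiss.normalize_L2(query)
    distances, indices = index.search(query, 3)  # top-3 neighbours, as in infer()
    print(indices[0], distances[0])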
@@ -623,6 +684,14 @@ async def chat_input_callback(*args):
 
 
 def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English"):
+    if isinstance(image_input, dict):  # if upload from sketcher_input, input contains image and mask
+        image_input = image_input['background']
+
+    if isinstance(image_input, str):
+        image_input = Image.open(io.BytesIO(base64.b64decode(image_input)))
+    elif isinstance(image_input, bytes):
+        image_input = Image.open(io.BytesIO(image_input))
+
 
     click_state = [[], [], []]
     image_input = image_resize(image_input, res=1024)
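
The new guard normalizes the three upload shapes `upload_callback` can now receive: a dict from `gr.ImageEditor` (image under the 'background' key), a base64-encoded string, or raw bytes. A standalone round-trip check of the two decode branches, on a synthetic image rather than app data:

    import base64, io
    from PIL import Image

    img = Image.new("RGB", (8, 8), "red")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    raw = buf.getvalue()

    from_bytes = Image.open(io.BytesIO(raw))                                      # bytes branch
    from_b64 = Image.open(io.BytesIO(base64.b64decode(base64.b64encode(raw))))    # str branch
    print(from_bytes.size, from_b64.size)  # (8, 8) (8, 8)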
@@ -938,13 +1007,15 @@ async def inference_traject(origin_image,sketcher_image, enable_wiki, language,
 
     if trace_type=="Trace+Seg":
         input_mask = np.array(out['mask'].convert('P'))
-        image_input = mask_painter(np.array(image_input), input_mask, background_alpha=0
+        image_input = mask_painter(np.array(image_input), input_mask, background_alpha=0)
+        d3_input=mask_painter(np.array(image_input), input_mask)
         crop_save_path=out['crop_save_path']
 
     else:
         image_input = Image.fromarray(np.array(origin_image))
         draw = ImageDraw.Draw(image_input)
         draw.rectangle(boxes, outline='red', width=2)
+        d3_input=image_input
         cropped_image = origin_image.crop(boxes)
         cropped_image.save('temp.png')
         crop_save_path='temp.png'
@@ -977,14 +1048,14 @@ async def inference_traject(origin_image,sketcher_image, enable_wiki, language,
     try:
         audio_output = await texttospeech(read_info, language,autoplay)
         # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
-        return state, state,image_input,audio_output
+        return state, state,image_input,audio_output,crop_save_path,d3_input
 
 
     except Exception as e:
         state = state + [(None, f"Error during TTS prediction: {str(e)}")]
         print(f"Error during TTS prediction: {str(e)}")
         # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
-        return state, state, image_input,audio_output
+        return state, state, image_input,audio_output,crop_save_path,d3_input
 
 
     else:
@@ -1290,11 +1361,10 @@ def create_ui():
 
         with gr.Row():
 
-            with gr.Column():
+            with gr.Column(scale=6):
                 with gr.Column(visible=False) as modules_not_need_gpt:
                     with gr.Tab("Base(GPT Power)") as base_tab:
                         image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
-                        example_image = gr.Image(type="pil", interactive=False, visible=False)
                         with gr.Row():
                             name_label_base = gr.Button(value="Name: ")
                             artist_label_base = gr.Button(value="Artist: ")
@@ -1304,45 +1374,51 @@ def create_ui():
                     with gr.Tab("Click") as click_tab:
                         image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
                         example_image = gr.Image(type="pil", interactive=False, visible=False)
+                        # example_image_click = gr.Image(type="pil", interactive=False, visible=False)
                         with gr.Row():
                             name_label = gr.Button(value="Name: ")
                             artist_label = gr.Button(value="Artist: ")
                             year_label = gr.Button(value="Year: ")
                             material_label = gr.Button(value="Material: ")
                         with gr.Row():
-                            with gr.Row():
-                                focus_type = gr.Radio(
-                                    choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
-                                    value="CFV-D",
-                                    label="Information Type",
-                                    interactive=True)
-                                point_prompt = gr.Radio(
-                                    choices=["Positive", "Negative"],
-                                    value="Positive",
-                                    label="Point Prompt",
-                                    interactive=True)
-                                click_mode = gr.Radio(
-                                    choices=["Continuous", "Single"],
-                                    value="Continuous",
-                                    label="Clicking Mode",
-                                    interactive=True)
-                            with gr.Row():
-                                clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
-                                clear_button_image = gr.Button(value="Clear Image", interactive=True)
+                            with gr.Column():
+                                with gr.Row():
+                                    focus_type = gr.Radio(
+                                        choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
+                                        value="CFV-D",
+                                        label="Information Type",
+                                        interactive=True,
+                                        scale=4)
+
+                                with gr.Row():
+                                    point_prompt = gr.Radio(
+                                        choices=["Positive", "Negative"],
+                                        value="Positive",
+                                        label="Point Prompt",
+                                        scale=5,
+                                        interactive=True)
+                                    click_mode = gr.Radio(
+                                        choices=["Continuous", "Single"],
+                                        value="Continuous",
+                                        label="Clicking Mode",
+                                        scale=5,
+                                        interactive=True)
+                            with gr.Column():
+                                with gr.Row():
+                                    submit_button_click=gr.Button(value="Submit", interactive=True,variant='primary',scale=2)
+                                with gr.Row():
+                                    clear_button_click = gr.Button(value="Clear Clicks", interactive=True,scale=2)
+                                    clear_button_image = gr.Button(value="Clear Image", interactive=True,scale=2)
 
                     with gr.Tab("Trajectory (beta)") as traj_tab:
                         # sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=10,
                         #                                elem_id="image_sketcher")
                         sketcher_input = gr.ImageEditor(type="pil", interactive=True,
                                                         elem_id="image_sketcher")
-
+                        # example_image_traj = gr.Image(type="pil", interactive=False, visible=False)
                         with gr.Row():
-                            submit_button_sketcher = gr.Button(value="Submit", interactive=True)
                             clear_button_sketcher = gr.Button(value="Clear Sketch", interactive=True)
+                            submit_button_sketcher = gr.Button(value="Submit", interactive=True)
                         with gr.Row():
                             with gr.Row():
                                 focus_type_sketch = gr.Radio(
@@ -1354,9 +1430,9 @@ def create_ui():
                                     choices=["Trace+Seg", "Trace"],
                                     value="Trace+Seg",
                                     label="Trace Type",
-                                    interactive=True)
+                                    interactive=True)
 
-                with gr.Column(visible=False) as modules_need_gpt1:
+                with gr.Column(visible=False,scale=4) as modules_need_gpt1:
                     with gr.Row():
                         sentiment = gr.Radio(
                             choices=["Positive", "Natural", "Negative"],
@@ -1395,7 +1471,7 @@ def create_ui():
 
 
 
-            with gr.Column():
+            with gr.Column(scale=5):
                 with gr.Column(visible=True) as module_key_input:
                     openai_api_key = gr.Textbox(
                         placeholder="Input openAI API key",
@@ -1454,7 +1530,7 @@ def create_ui():
 
                 with gr.Column():
                     with gr.Column():
-                        gr.Radio([
+                        gr.Radio(["Other Paintings by the Artist"], label="Artist", info="Who is the artist?🧑🎨"),
                         gr.Radio(["Oil Painting","Printmaking","Watercolor Painting","Drawing"], label="Art Forms", info="What are the art forms?🎨"),
                         gr.Radio(["Renaissance", "Baroque", "Impressionism","Modernism"], label="Period", info="Which art period?⏳"),
                         # to be done
@@ -1582,20 +1658,9 @@ def create_ui():
         # api_name="run",
         # )
         run_button.click(
-            fn=
-            inputs=[
-                prompt,
-                negative_prompt,
-                use_negative_prompt,
-                seed,
-                width,
-                height,
-                guidance_scale,
-                num_inference_steps,
-                randomize_seed,
-                num_images
-            ],
-            outputs=[result, seed]
+            fn=infer,
+            inputs=[new_crop_save_path],
+            outputs=[result]
         )
 
         ###############################################################################
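
With this change `run_button` drives the SigLIP/FAISS retrieval instead of text-to-image generation: `infer` receives the saved crop path and returns a list of PIL images. A minimal sketch of the same wiring, with a stub `infer` and hypothetical component definitions (the app's real `new_crop_save_path` and `result` are declared outside this diff; `result` is assumed to be a gallery since `infer` returns an image list):

    import gradio as gr

    def infer(image_path):
        return []  # stand-in: the real infer() returns retrieved PIL images

    with gr.Blocks() as demo:
        new_crop_save_path = gr.Textbox(visible=False)  # assumed filepath holder
        run_button = gr.Button("Run")
        result = gr.Gallery()
        run_button.click(fn=infer, inputs=[new_crop_save_path], outputs=[result])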
@@ -1825,12 +1890,12 @@ def create_ui():
                           [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
                            image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
 
-        image_input.upload(upload_callback, [image_input, state, visual_chatgpt, openai_api_key],
-                           [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
-                            image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
-        sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt, openai_api_key],
-                           [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
-                            image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
+        # image_input.upload(upload_callback, [image_input, state, visual_chatgpt, openai_api_key],
+        #                    [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
+        #                     image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
+        # sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt, openai_api_key],
+        #                    [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
+        #                     image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
         chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
                           [chatbot, state, aux_state,output_audio])
         chat_input.submit(lambda: "", None, chat_input)
@@ -1904,7 +1969,7 @@ def create_ui():
                 origin_image,sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
                 original_size, input_size, text_refiner,focus_type_sketch,paragraph,openai_api_key,auto_play,Input_sketch
             ],
-            outputs=[chatbot, state, sketcher_input,output_audio],
+            outputs=[chatbot, state, sketcher_input,output_audio,new_crop_save_path,input_image],
            show_progress=False, queue=True
        )
 
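
The widened outputs list matches the two extra values (`crop_save_path`, `d3_input`) now returned by the trajectory handler in the -977 hunk: Gradio maps a handler's return values to `outputs` positionally, so the two arities must move in step. A toy sketch of that rule with made-up components:

    import gradio as gr

    def handler():
        # six return values for six output components, matched by position
        return "a", "b", "c", "d", "e", "f"

    with gr.Blocks() as demo:
        outs = [gr.Textbox() for _ in range(6)]
        gr.Button("Go").click(fn=handler, inputs=None, outputs=outs)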