Niki Zhang committed
Commit e3ed104 · verified · 1 Parent(s): 16e7f13

Update app.py

Files changed (1): app.py +120 -55
app.py CHANGED
@@ -1,4 +1,5 @@
 from io import BytesIO
+import io
 from math import inf
 import os
 import base64
@@ -26,7 +27,15 @@ import easyocr
 import re
 import edge_tts
 from langchain import __version__
-
+import torch
+import gradio as gr
+from transformers import AutoProcessor, SiglipModel
+import faiss
+from huggingface_hub import hf_hub_download
+from datasets import load_dataset
+import pandas as pd
+import requests
+import spaces
 # Print the current version of LangChain
 print(f"Current LangChain version: {__version__}")
 # import tts
@@ -37,7 +46,9 @@ print(f"Current LangChain version: {__version__}")
 
 
 # import spaces #
+import threading
 
+lock = threading.Lock()
 import os
 # import uuid
 # from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
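Note: the new module-level `lock` presumably serializes access to shared model state across Gradio's concurrent requests. A minimal sketch of the pattern (the guarded resource below is hypothetical, not from app.py):

```python
import threading

lock = threading.Lock()
shared_state = {"calls": 0}  # hypothetical shared resource

def run_model(x):
    # Only one thread at a time may touch the shared resource.
    with lock:
        shared_state["calls"] += 1
        return x * 2
```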
@@ -304,6 +315,56 @@ def make3d(images):
 ###############################################################################
 
 
+###############################################################################
+############# This part is for sCLIP #############
+###############################################################################
+
+# download model and dataset
+hf_hub_download("merve/siglip-faiss-wikiart", "siglip_10k_latest.index", local_dir="./")
+hf_hub_download("merve/siglip-faiss-wikiart", "wikiart_10k_latest.csv", local_dir="./")
+
+# read index, dataset and load siglip model and processor
+index = faiss.read_index("./siglip_10k_latest.index")
+df = pd.read_csv("./wikiart_10k_latest.csv")
+device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
+processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+slipmodel = SiglipModel.from_pretrained("google/siglip-base-patch16-224").to(device)
+
+
+def read_image_from_url(url):
+    response = requests.get(url)
+    img = Image.open(BytesIO(response.content)).convert("RGB")
+    return img
+
+#@spaces.GPU
+def extract_features_siglip(image):
+    with torch.no_grad():
+        inputs = processor(images=image, return_tensors="pt").to(device)
+        image_features = slipmodel.get_image_features(**inputs)
+    return image_features
+
+@spaces.GPU
+def infer(image_path):
+    input_image = Image.open(image_path).convert("RGB")
+    input_features = extract_features_siglip(input_image.convert("RGB"))
+    input_features = input_features.detach().cpu().numpy()
+    input_features = np.float32(input_features)
+    faiss.normalize_L2(input_features)
+    distances, indices = index.search(input_features, 3)
+    gallery_output = []
+    for i, v in enumerate(indices[0]):
+        sim = -distances[0][i]
+        image_url = df.iloc[v]["Link"]
+        img_retrieved = read_image_from_url(image_url)
+        gallery_output.append(img_retrieved)
+    return gallery_output
+
+
+###############################################################################
+############# Above part is for sCLIP #############
+###############################################################################
+
+
 ###############################################################################
 ############# this part is for text to image #############
 ###############################################################################
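Note: the retrieval added here L2-normalizes the query embedding before searching the FAISS index, so inner-product scores behave like cosine similarity. A minimal self-contained sketch of that technique (random vectors stand in for SigLIP embeddings; the dimensions and names are illustrative, not from the commit):

```python
import faiss
import numpy as np

d = 768                       # embedding size (SigLIP-base uses 768)
xb = np.random.rand(100, d).astype("float32")
faiss.normalize_L2(xb)        # unit-normalize database vectors

index = faiss.IndexFlatIP(d)  # inner product == cosine on unit vectors
index.add(xb)

xq = np.random.rand(1, d).astype("float32")
faiss.normalize_L2(xq)        # normalize the query the same way
distances, indices = index.search(xq, 3)  # top-3 nearest neighbors
print(indices[0], distances[0])
```

In the committed `infer`, `sim = -distances[0][i]` is computed but never used; whether negation is the right sign convention depends on whether the downloaded index was built on L2 distance or inner product.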
@@ -623,6 +684,14 @@ async def chat_input_callback(*args):
 
 
 def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English"):
+    if isinstance(image_input, dict):  # if upload from sketcher_input, input contains image and mask
+        image_input = image_input['background']
+
+    if isinstance(image_input, str):
+        image_input = Image.open(io.BytesIO(base64.b64decode(image_input)))
+    elif isinstance(image_input, bytes):
+        image_input = Image.open(io.BytesIO(image_input))
+
 
     click_state = [[], [], []]
     image_input = image_resize(image_input, res=1024)
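Note: `upload_callback` now normalizes whatever Gradio hands it: an ImageEditor dict, a base64 string, raw bytes, or a PIL image. A standalone sketch of the same normalization, runnable on its own (the round-trip input is fabricated for illustration):

```python
import base64
import io
from PIL import Image

def to_pil(image_input):
    # ImageEditor components deliver a dict; the 'background' key holds the image.
    if isinstance(image_input, dict):
        image_input = image_input["background"]
    if isinstance(image_input, str):      # base64-encoded payload
        image_input = Image.open(io.BytesIO(base64.b64decode(image_input)))
    elif isinstance(image_input, bytes):  # raw encoded bytes
        image_input = Image.open(io.BytesIO(image_input))
    return image_input

# Round-trip check with a tiny generated image:
buf = io.BytesIO()
Image.new("RGB", (4, 4)).save(buf, format="PNG")
print(to_pil(buf.getvalue()).size)  # (4, 4)
```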
@@ -938,13 +1007,15 @@ async def inference_traject(origin_image,sketcher_image, enable_wiki, language,
 
     if trace_type=="Trace+Seg":
         input_mask = np.array(out['mask'].convert('P'))
-        image_input = mask_painter(np.array(image_input), input_mask, background_alpha=0 )
+        image_input = mask_painter(np.array(image_input), input_mask, background_alpha=0)
+        d3_input=mask_painter(np.array(image_input), input_mask)
         crop_save_path=out['crop_save_path']
 
     else:
         image_input = Image.fromarray(np.array(origin_image))
         draw = ImageDraw.Draw(image_input)
         draw.rectangle(boxes, outline='red', width=2)
+        d3_input=image_input
         cropped_image = origin_image.crop(boxes)
         cropped_image.save('temp.png')
         crop_save_path='temp.png'
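Note: `mask_painter`, which produces the new `d3_input`, is defined elsewhere in app.py and is not part of this diff. A rough stand-in for what such a helper typically does (the blending scheme and all names here are assumptions, not the app's actual implementation):

```python
import numpy as np

def mask_painter_sketch(image, mask, color=(255, 0, 0), mask_alpha=0.5, background_alpha=1.0):
    # Blend a flat color into the masked region; optionally dim the background.
    out = image.astype(np.float32)
    out[mask == 0] *= background_alpha  # background_alpha=0 blacks out everything else
    overlay = np.array(color, dtype=np.float32)
    out[mask > 0] = (1 - mask_alpha) * out[mask > 0] + mask_alpha * overlay
    return out.astype(np.uint8)

img = np.full((4, 4, 3), 200, dtype=np.uint8)
m = np.zeros((4, 4), dtype=np.uint8)
m[1:3, 1:3] = 1
print(mask_painter_sketch(img, m, background_alpha=0)[0, 0])  # dimmed background pixel
print(mask_painter_sketch(img, m)[1, 1])                      # blended masked pixel
```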
@@ -977,14 +1048,14 @@ async def inference_traject(origin_image,sketcher_image, enable_wiki, language,
         try:
             audio_output = await texttospeech(read_info, language,autoplay)
             # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
-            return state, state,image_input,audio_output
+            return state, state,image_input,audio_output,crop_save_path,d3_input
 
 
         except Exception as e:
             state = state + [(None, f"Error during TTS prediction: {str(e)}")]
             print(f"Error during TTS prediction: {str(e)}")
             # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
-            return state, state, image_input,audio_output
+            return state, state, image_input,audio_output,crop_save_path,d3_input
 
 
         else:
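Note: these handlers now return six values, so every Gradio event that calls them must list six output components (the matching `outputs=` change is in the last hunk below). A minimal sketch of that contract, with hypothetical component names:

```python
import gradio as gr

def handler(x):
    # The returned tuple must line up, in order, with the outputs list below.
    return f"echo: {x}", len(x)

with gr.Blocks() as demo:
    box = gr.Textbox()
    text_out = gr.Textbox()
    count_out = gr.Number()
    box.submit(handler, inputs=[box], outputs=[text_out, count_out])
```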
@@ -1290,11 +1361,10 @@ def create_ui():
 
         with gr.Row():
 
-            with gr.Column():
+            with gr.Column(scale=6):
                 with gr.Column(visible=False) as modules_not_need_gpt:
                     with gr.Tab("Base(GPT Power)") as base_tab:
                         image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
-                        example_image = gr.Image(type="pil", interactive=False, visible=False)
                     with gr.Row():
                         name_label_base = gr.Button(value="Name: ")
                         artist_label_base = gr.Button(value="Artist: ")
@@ -1304,45 +1374,51 @@ def create_ui():
                     with gr.Tab("Click") as click_tab:
                         image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
                         example_image = gr.Image(type="pil", interactive=False, visible=False)
+                        # example_image_click = gr.Image(type="pil", interactive=False, visible=False)
                         with gr.Row():
                             name_label = gr.Button(value="Name: ")
                             artist_label = gr.Button(value="Artist: ")
                             year_label = gr.Button(value="Year: ")
                             material_label = gr.Button(value="Material: ")
                         with gr.Row():
-                            with gr.Row():
-                                focus_type = gr.Radio(
-                                    choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
-                                    value="CFV-D",
-                                    label="Information Type",
+                            with gr.Column():
+                                with gr.Row():
+                                    focus_type = gr.Radio(
+                                        choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
+                                        value="CFV-D",
+                                        label="Information Type",
+                                        interactive=True,
+                                        scale=4)
+
+                                with gr.Row():
+                                    point_prompt = gr.Radio(
+                                        choices=["Positive", "Negative"],
+                                        value="Positive",
+                                        label="Point Prompt",
+                                        scale=5,
                                         interactive=True)
-                            with gr.Row():
-                                submit_button_click=gr.Button(value="Submit", interactive=True,variant='primary',size="sm")
-                            with gr.Row():
-                                with gr.Row():
-                                    point_prompt = gr.Radio(
-                                        choices=["Positive", "Negative"],
-                                        value="Positive",
-                                        label="Point Prompt",
-                                        interactive=True)
-                                    click_mode = gr.Radio(
-                                        choices=["Continuous", "Single"],
-                                        value="Continuous",
-                                        label="Clicking Mode",
-                                        interactive=True)
-                                with gr.Row():
-                                    clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
-                                    clear_button_image = gr.Button(value="Clear Image", interactive=True)
+                                    click_mode = gr.Radio(
+                                        choices=["Continuous", "Single"],
+                                        value="Continuous",
+                                        label="Clicking Mode",
+                                        scale=5,
+                                        interactive=True)
+                            with gr.Column():
+                                with gr.Row():
+                                    submit_button_click=gr.Button(value="Submit", interactive=True,variant='primary',scale=2)
+                                with gr.Row():
+                                    clear_button_click = gr.Button(value="Clear Clicks", interactive=True,scale=2)
+                                    clear_button_image = gr.Button(value="Clear Image", interactive=True,scale=2)
 
                     with gr.Tab("Trajectory (beta)") as traj_tab:
                         # sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=10,
                         #                                elem_id="image_sketcher")
                         sketcher_input = gr.ImageEditor(type="pil", interactive=True,
                                                         elem_id="image_sketcher")
-                        example_image = gr.Image(type="pil", interactive=False, visible=False)
+                        # example_image_traj = gr.Image(type="pil", interactive=False, visible=False)
                         with gr.Row():
-                            submit_button_sketcher = gr.Button(value="Submit", interactive=True)
                             clear_button_sketcher = gr.Button(value="Clear Sketch", interactive=True)
+                            submit_button_sketcher = gr.Button(value="Submit", interactive=True)
                         with gr.Row():
                             with gr.Row():
                                 focus_type_sketch = gr.Radio(
@@ -1354,9 +1430,9 @@ def create_ui():
                                     choices=["Trace+Seg", "Trace"],
                                     value="Trace+Seg",
                                     label="Trace Type",
-                                        interactive=True)
+                                    interactive=True)
 
-            with gr.Column(visible=False) as modules_need_gpt1:
+            with gr.Column(visible=False,scale=4) as modules_need_gpt1:
                 with gr.Row():
                     sentiment = gr.Radio(
                         choices=["Positive", "Natural", "Negative"],
@@ -1395,7 +1471,7 @@ def create_ui():
 
 
 
-        with gr.Column():
+        with gr.Column(scale=5):
             with gr.Column(visible=True) as module_key_input:
                 openai_api_key = gr.Textbox(
                     placeholder="Input openAI API key",
@@ -1454,7 +1530,7 @@ def create_ui():
 
         with gr.Column():
             with gr.Column():
-                gr.Radio([artist], label="Artist", info="Who is the artist?🧑‍🎨"),
+                gr.Radio(["Other Paintings by the Artist"], label="Artist", info="Who is the artist?🧑‍🎨"),
                 gr.Radio(["Oil Painting","Printmaking","Watercolor Painting","Drawing"], label="Art Forms", info="What are the art forms?🎨"),
                 gr.Radio(["Renaissance", "Baroque", "Impressionism","Modernism"], label="Period", info="Which art period?⏳"),
                 # to be done
@@ -1582,20 +1658,9 @@ def create_ui():
             #     api_name="run",
             # )
             run_button.click(
-                fn=generate,
-                inputs=[
-                    prompt,
-                    negative_prompt,
-                    use_negative_prompt,
-                    seed,
-                    width,
-                    height,
-                    guidance_scale,
-                    num_inference_steps,
-                    randomize_seed,
-                    num_images
-                ],
-                outputs=[result, seed]
+                fn=infer,
+                inputs=[new_crop_save_path],
+                outputs=[result]
             )
 
 ###############################################################################
@@ -1825,12 +1890,12 @@ def create_ui():
                           [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
                            image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
 
-        image_input.upload(upload_callback, [image_input, state, visual_chatgpt, openai_api_key],
-                           [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
-                            image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
-        sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt, openai_api_key],
-                              [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
-                               image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
+        # image_input.upload(upload_callback, [image_input, state, visual_chatgpt, openai_api_key],
+        #                    [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
+        #                     image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
+        # sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt, openai_api_key],
+        #                       [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
+        #                        image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
         chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
                           [chatbot, state, aux_state,output_audio])
         chat_input.submit(lambda: "", None, chat_input)
@@ -1904,7 +1969,7 @@ def create_ui():
                 origin_image,sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
                 original_size, input_size, text_refiner,focus_type_sketch,paragraph,openai_api_key,auto_play,Input_sketch
             ],
-            outputs=[chatbot, state, sketcher_input,output_audio],
+            outputs=[chatbot, state, sketcher_input,output_audio,new_crop_save_path,input_image],
             show_progress=False, queue=True
         )
 
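Note: taken together with the earlier hunks, the sketch handler now returns `crop_save_path` and `d3_input`, the sketcher event writes them into `new_crop_save_path` and `input_image`, and the run button feeds `new_crop_save_path` into the new `infer` to retrieve similar paintings. A minimal sketch of that hand-off pattern (all component names here are hypothetical):

```python
import gradio as gr

def produce(text):
    return f"processed:{text}"           # value stashed for a later event

def consume(stashed):
    return f"used {stashed!r} as input"  # reads what the first event stored

with gr.Blocks() as demo:
    inp = gr.Textbox()
    stash = gr.Textbox(visible=False)    # hidden component carrying state between events
    out = gr.Textbox()
    run = gr.Button("Run")
    inp.submit(produce, inputs=[inp], outputs=[stash])
    run.click(consume, inputs=[stash], outputs=[out])
```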