Niki Zhang committed
Commit acb115a · verified · 1 Parent(s): b03c1aa

Update app.py

Files changed (1)
  app.py +99 -90
app.py CHANGED
@@ -347,74 +347,74 @@ def extract_features_siglip(image):
     return image_features
 
 @spaces.GPU
-def infer(crop_image_path,full_image_path,state,language,session_type,task_type=None):
+def infer(crop_image_path,full_image_path,state,language,task_type=None):
     print("task type",task_type)
     gallery_output = []
-    if session_type=="Session 1":
-        if task_type=="task 1":
-            gallery_output.append("recomendation_pic/1.8.jpg")
-            gallery_output.append("recomendation_pic/1.9.jpg")
-            input_image = Image.open(full_image_path).convert("RGB")
-            input_features = extract_features_siglip(input_image.convert("RGB"))
-            input_features = input_features.detach().cpu().numpy()
-            input_features = np.float32(input_features)
-            faiss.normalize_L2(input_features)
-            distances, indices = index.search(input_features, 2)
-            for i,v in enumerate(indices[0]):
-                sim = -distances[0][i]
-                image_url = df.iloc[v]["Link"]
-                img_retrieved = read_image_from_url(image_url)
-                gallery_output.append(img_retrieved)
-            if language=="English":
-                msg="🖼️ Please refer to the section below to see the recommended results."
-            else:
-                msg="🖼️ 请到下方查看推荐结果。"
-            state+=[(None,msg)]
-
-            return gallery_output,state,state
-        elif task_type=="task 2":
-            gallery_output.append("recomendation_pic/2.8.jpg")
-            gallery_output.append("recomendation_pic/2.9.png")
-            input_image = Image.open(full_image_path).convert("RGB")
-            input_features = extract_features_siglip(input_image.convert("RGB"))
-            input_features = input_features.detach().cpu().numpy()
-            input_features = np.float32(input_features)
-            faiss.normalize_L2(input_features)
-            distances, indices = index.search(input_features, 2)
-            for i,v in enumerate(indices[0]):
-                sim = -distances[0][i]
-                image_url = df.iloc[v]["Link"]
-                img_retrieved = read_image_from_url(image_url)
-                gallery_output.append(img_retrieved)
-            if language=="English":
-                msg="🖼️ Please refer to the section below to see the recommended results."
-            else:
-                msg="🖼️ 请到下方查看推荐结果。"
-            state+=[(None,msg)]
-
-            return gallery_output,state,state
-
-        elif task_type=="task 3":
-            gallery_output.append("recomendation_pic/3.8.png")
-            gallery_output.append("recomendation_pic/3.9.png")
-            input_image = Image.open(full_image_path).convert("RGB")
-            input_features = extract_features_siglip(input_image.convert("RGB"))
-            input_features = input_features.detach().cpu().numpy()
-            input_features = np.float32(input_features)
-            faiss.normalize_L2(input_features)
-            distances, indices = index.search(input_features, 2)
-            for i,v in enumerate(indices[0]):
-                sim = -distances[0][i]
-                image_url = df.iloc[v]["Link"]
-                img_retrieved = read_image_from_url(image_url)
-                gallery_output.append(img_retrieved)
-            if language=="English":
-                msg="🖼️ Please refer to the section below to see the recommended results."
-            else:
-                msg="🖼️ 请到下方查看推荐结果。"
-            state+=[(None,msg)]
-
-            return gallery_output,state,state
+
+    if task_type=="task 1":
+        gallery_output.append("recomendation_pic/1.8.jpg")
+        gallery_output.append("recomendation_pic/1.9.jpg")
+        input_image = Image.open(full_image_path).convert("RGB")
+        input_features = extract_features_siglip(input_image.convert("RGB"))
+        input_features = input_features.detach().cpu().numpy()
+        input_features = np.float32(input_features)
+        faiss.normalize_L2(input_features)
+        distances, indices = index.search(input_features, 2)
+        for i,v in enumerate(indices[0]):
+            sim = -distances[0][i]
+            image_url = df.iloc[v]["Link"]
+            img_retrieved = read_image_from_url(image_url)
+            gallery_output.append(img_retrieved)
+        if language=="English":
+            msg="🖼️ Please refer to the section below to see the recommended results."
+        else:
+            msg="🖼️ 请到下方查看推荐结果。"
+        state+=[(None,msg)]
+
+        return gallery_output,state,state
+    elif task_type=="task 2":
+        gallery_output.append("recomendation_pic/2.8.jpg")
+        gallery_output.append("recomendation_pic/2.9.png")
+        input_image = Image.open(full_image_path).convert("RGB")
+        input_features = extract_features_siglip(input_image.convert("RGB"))
+        input_features = input_features.detach().cpu().numpy()
+        input_features = np.float32(input_features)
+        faiss.normalize_L2(input_features)
+        distances, indices = index.search(input_features, 2)
+        for i,v in enumerate(indices[0]):
+            sim = -distances[0][i]
+            image_url = df.iloc[v]["Link"]
+            img_retrieved = read_image_from_url(image_url)
+            gallery_output.append(img_retrieved)
+        if language=="English":
+            msg="🖼️ Please refer to the section below to see the recommended results."
+        else:
+            msg="🖼️ 请到下方查看推荐结果。"
+        state+=[(None,msg)]
+
+        return gallery_output,state,state
+
+    elif task_type=="task 3":
+        gallery_output.append("recomendation_pic/3.8.png")
+        gallery_output.append("recomendation_pic/3.9.png")
+        input_image = Image.open(full_image_path).convert("RGB")
+        input_features = extract_features_siglip(input_image.convert("RGB"))
+        input_features = input_features.detach().cpu().numpy()
+        input_features = np.float32(input_features)
+        faiss.normalize_L2(input_features)
+        distances, indices = index.search(input_features, 2)
+        for i,v in enumerate(indices[0]):
+            sim = -distances[0][i]
+            image_url = df.iloc[v]["Link"]
+            img_retrieved = read_image_from_url(image_url)
+            gallery_output.append(img_retrieved)
+        if language=="English":
+            msg="🖼️ Please refer to the section below to see the recommended results."
+        else:
+            msg="🖼️ 请到下方查看推荐结果。"
+        state+=[(None,msg)]
+
+        return gallery_output,state,state
 
     elif crop_image_path:
         input_image = Image.open(crop_image_path).convert("RGB")
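After dropping the `session_type` gate, `infer` branches on `task_type` alone, and each branch repeats the same retrieval recipe: seed the gallery with two curated files, embed the full image with SigLIP, L2-normalize, and pull the two nearest neighbours from the FAISS index. A minimal sketch of that shared step, factored into a helper (`retrieve_similar` is a hypothetical name; `extract_features_siglip`, `index`, `df`, and `read_image_from_url` are taken from app.py):

    import faiss
    import numpy as np
    from PIL import Image

    def retrieve_similar(full_image_path, k=2):
        # embed the query image with SigLIP (a (1, d) torch tensor)
        image = Image.open(full_image_path).convert("RGB")
        feats = extract_features_siglip(image)
        feats = np.float32(feats.detach().cpu().numpy())
        # L2-normalize so inner-product search behaves like cosine similarity
        faiss.normalize_L2(feats)
        distances, indices = index.search(feats, k)
        # map row ids back to artwork URLs and fetch the images
        return [read_image_from_url(df.iloc[v]["Link"]) for v in indices[0]]

With a helper like this, each `task_type` branch would reduce to the two seed images plus `gallery_output += retrieve_similar(full_image_path)`; as written, the three copies of the retrieval code must be kept in sync by hand.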
@@ -1090,7 +1090,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     Image.open(out["crop_save_path"]).save(new_crop_save_path)
     print("new crop save",new_crop_save_path)
 
-    yield state, state, click_state, image_input_nobackground, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground
+    return state, state, click_state, image_input_nobackground, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground
 
 
 query_focus_en = [
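On the `yield` → `return` change in `inference_click`: Gradio treats a handler that yields as a generator and streams each yielded tuple to the outputs, while a plain `return` delivers a single final update. A toy illustration (component names are illustrative only):

    import time
    import gradio as gr

    with gr.Blocks() as demo:
        out = gr.Textbox()
        run = gr.Button("run")

        def streamed():
            for i in range(3):
                time.sleep(0.5)
                yield f"step {i}"   # the textbox repaints after every yield

        run.click(streamed, [], [out])

One caveat: if any `yield` remains elsewhere in `inference_click`, the function is still a generator, and a `return` with a value there ends it without emitting that tuple (the value becomes the `StopIteration` payload), so this change is only safe if the handler no longer yields anywhere.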
@@ -1646,7 +1646,7 @@ async def texttospeech(text, language,gender='female'):
     return None
 
 # give the reason of recommendation
-async def associate(image_path,new_crop,openai_api_key,language,autoplay,length,log_state,sort_score,narritive,evt: gr.SelectData):
+async def associate(image_path,new_crop,openai_api_key,language,autoplay,length,log_state,sort_score,narritive,state,evt: gr.SelectData):
     persona=naritive_mapping[narritive]
     rec_path=evt._data['value']['image']['path']
     index=evt.index
@@ -1658,7 +1658,7 @@ async def associate(image_path,new_crop,openai_api_key,language,autoplay,length,
     image_paths=[image_path,rec_path]
     result=get_gpt_response(openai_api_key, image_paths, prompt)
     print("recommend result",result)
-    reason = [(None, f"{result}")]
+    state += [(None, f"{result}")]
     log_state = log_state + [(narritive, None)]
     log_state = log_state + [(f"image sort ranking {sort_score}", None)]
     log_state = log_state + [(None, f"{result}")]
@@ -1668,11 +1668,11 @@ async def associate(image_path,new_crop,openai_api_key,language,autoplay,length,
     audio_output=None
     if autoplay:
         audio_output = await texttospeech(read_info, language)
-    return reason,audio_output,log_state,index,gr.update(value=[])
+    return state,state,audio_output,log_state,index,gr.update(value=[])
 
-def change_naritive(session_type,image_input, chatbot, state, click_state, paragraph, origin_image,narritive,language="English"):
+def change_naritive(session_type,image_input, state, click_state, paragraph, origin_image,narritive,task_instruct,gallery_output,reco_reasons,language="English"):
     if session_type=="Session 1":
-        return None, [], [], [[], [], []], "", None, []
+        return None, [], [], [[], [], []], "", None, None, [], [],[]
     else:
         if language=="English":
             if narritive=="Third-person" :
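The three hunks above serve one change: `associate` now threads the accumulated chat `state` through the gallery-select event instead of building a detached one-off `reason` list, and it returns the history twice so the same value can drive both the visible chatbot and a persistent state copy. A stripped-down sketch of the pattern (handler body illustrative; in Gradio, a `.select` handler receives the `gr.SelectData` event without it being listed in `inputs`):

    import gradio as gr

    def on_select(state, evt: gr.SelectData):
        # evt.index identifies the clicked gallery item
        state = state + [(None, f"reason for item {evt.index}")]
        return state, state   # one copy per output component

    # wiring: gallery.select(on_select, inputs=[chat_state], outputs=[chatbot, chat_state])

Returning the same list twice looks redundant, but it is how Gradio fans a single value out to two output components.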
@@ -1720,7 +1720,7 @@ def change_naritive(session_type,image_input, chatbot, state, click_state, parag
             ]
 
 
-    return image_input, state, state, click_state, paragraph, origin_image
+    return image_input, state, state, click_state, paragraph, origin_image,task_instruct,gallery_output,reco_reasons,reco_reasons
 
 
 def print_like_dislike(x: gr.LikeData,state,log_state):
@@ -1766,7 +1766,7 @@ def create_ui():
     examples = [
         ["test_images/1.The Ambassadors.jpg","test_images/task1.jpg","task 1"],
         ["test_images/2.Football Players.jpg","test_images/task2.jpg","task 2"],
-        ["test_images/3.Along the River during the Qingming Festival.jpeg","test_images/task3.jpg","task 3"],
+        ["test_images/3-square.jpg","test_images/task3.jpg","task 3"],
         # ["test_images/test4.jpg"],
         # ["test_images/test5.jpg"],
         # ["test_images/Picture5.png"],
@@ -1810,6 +1810,7 @@ def create_ui():
         # store the whole image path
         image_path=gr.State('')
         pic_index=gr.State(None)
+        recomended_state=gr.State([])
 
 
         with gr.Row():
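The new `recomended_state = gr.State([])` gives the recommendation chat its own per-session store, alongside `image_path` and `pic_index`. A self-contained sketch of the `gr.State` round trip it relies on:

    import gradio as gr

    with gr.Blocks() as demo:
        history = gr.State([])          # per-session value, invisible in the UI
        msg = gr.Textbox(label="message")
        chat = gr.Chatbot()

        def respond(text, hist):
            hist = hist + [(text, "noted")]
            return hist, hist           # refresh the Chatbot and persist the State

        msg.submit(respond, [msg, history], [chat, history])

Unlike a Python global, a `gr.State` value is scoped to one browser session, which is what lets `change_naritive` later hand the saved recommendation history back to `recommend_bot`.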
@@ -1821,8 +1822,7 @@ def create_ui():
         )
         with gr.Row():
             with gr.Column(scale=1,min_width=50,visible=False) as instruct:
-                task_instuction=gr.Image(type="pil", interactive=True, elem_classes="task_instruct",height=650,label=None)
-
+                task_instuction=gr.Image(type="pil", interactive=True, elem_classes="task_instruct",height=650,label=None)
             with gr.Column(scale=6):
                 with gr.Column(visible=False) as modules_not_need_gpt:
 
@@ -1941,6 +1941,7 @@ def create_ui():
             with gr.Column(scale=4):
                 with gr.Column(visible=True) as module_key_input:
                     openai_api_key = gr.Textbox(
+                        value="sk-proj-bxHhgjZV8TVgd1IupZrUT3BlbkFJvrthq6zIxpZVk3vwsvJ9",
                         placeholder="Input openAI API key",
                         show_label=False,
                         label="OpenAI API Key",
@@ -2206,14 +2207,14 @@ def create_ui():
         # )
         recommend_btn.click(
             fn=infer,
-            inputs=[new_crop_save_path,image_path,state,language,session_type,task_type],
+            inputs=[new_crop_save_path,image_path,state,language,task_type],
             outputs=[gallery_result,chatbot,state]
         )
 
         gallery_result.select(
             associate,
-            inputs=[image_path,new_crop_save_path,openai_api_key,language,auto_play,length,log_state,sort_rec,naritive],
-            outputs=[recommend_bot,output_audio,log_state,pic_index,recommend_score],
+            inputs=[image_path,new_crop_save_path,openai_api_key,language,auto_play,length,log_state,sort_rec,naritive,recomended_state],
+            outputs=[recommend_bot,recomended_state,output_audio,log_state,pic_index,recommend_score],
 
 
         )
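Both wirings have to stay in lock-step with the new signatures: `inputs` map positionally onto the handler's parameters (so `session_type` is dropped in the function and the click at the same time), and the handler's return tuple maps positionally onto `outputs` (hence `recomended_state` appearing on both sides of the select). A generic illustration of that contract:

    import gradio as gr

    with gr.Blocks() as demo:
        a, b = gr.Number(value=2), gr.Number(value=3)
        total = gr.Number(label="sum")
        add_btn = gr.Button("add")

        def add(x, y):        # len(inputs) parameters ...
            return x + y      # ... len(outputs) return values

        add_btn.click(add, inputs=[a, b], outputs=[total])

A mismatch on either side typically surfaces only when the event fires, not at launch, so these lists are the first place to look when a button silently breaks after a refactor like this one.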
@@ -2434,11 +2435,18 @@ def create_ui():
 
         # cap_everything_button.click(cap_everything, [paragraph, visual_chatgpt, language,auto_play],
         #                             [paragraph_output,output_audio])
-
+        def reset_and_add(origin_image):
+            new_prompt = "Positive"
+            new_add_icon = "assets/icons/plus-square-blue.png"
+            new_add_css = "tools_button_clicked"
+            new_minus_icon = "assets/icons/minus-square.png"
+            new_minus_css= "tools_button"
+            return [[],[],[]],origin_image, new_prompt, gr.update(icon=new_add_icon,elem_classes=new_add_css), gr.update(icon=new_minus_icon,elem_classes=new_minus_css)
+
         clear_button_click.click(
-            lambda x: ([[], [], []], x),
+            reset_and_add,
             [origin_image],
-            [click_state, image_input],
+            [click_state, image_input,point_prompt,add_button,minus_button],
            queue=False,
            show_progress=False
         )
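`reset_and_add` replaces the old two-output lambda and additionally restores the point-prompt radio and the two tool buttons via `gr.update`, which patches properties of an existing component (here `icon` and `elem_classes`) without recreating it. A minimal sketch of that update mechanism:

    import gradio as gr

    with gr.Blocks() as demo:
        tool = gr.Button("select", elem_classes="tools_button")

        def activate():
            # return an update object for the component listed in outputs
            return gr.update(elem_classes="tools_button_clicked")

        tool.click(activate, [], [tool])

The asset paths and class names in `reset_and_add` presumably mirror the ones used where the buttons are first declared, so the clear button re-applies the app's initial "Positive point" styling.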
@@ -2525,10 +2533,10 @@ def create_ui():
                                                     paragraph,artist,gender,image_path, log_state,history_log,output_audio])
 
         example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
-        example_image.change(
-            lambda:([],[]),
-            [],
-            [gallery_result,recommend_bot])
+        # example_image.change(
+        #     lambda:([],[]),
+        #     [],
+        #     [gallery_result,recommend_bot])
 
         # def on_click_tab_selected():
         #     if gpt_state ==1:
@@ -2672,20 +2680,21 @@ def create_ui():
 
         naritive.change(
             change_naritive,
-            [session_type, image_input, chatbot, state, click_state, paragraph, origin_image,naritive,language],
-            [image_input, chatbot, state, click_state, paragraph, origin_image,gallery_result],
+            [session_type, image_input, state, click_state, paragraph, origin_image,naritive,
+             task_instuction,gallery_result,recomended_state,language],
+            [image_input, chatbot, state, click_state, paragraph, origin_image,task_instuction,gallery_result,recomended_state,recommend_bot],
             queue=False,
             show_progress=False
 
         )
         def session_change():
             instruction=Image.open('test_images/task4.jpg')
-            return None, [], [], [[], [], []], "", None, [],[],instruction
+            return None, [], [], [[], [], []], "", None, [],[],instruction,"task 4"
 
         session_type.change(
             session_change,
             [],
-            [image_input, chatbot, state, click_state, paragraph, origin_image,history_log,log_state,task_instuction]
+            [image_input, chatbot, state, click_state, paragraph, origin_image,history_log,log_state,task_instuction,task_type]
         )
 
         # upvote_btn.click(
 
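The final hunks wire one function to many components: `change_naritive` and `session_change` reset the whole workspace by returning a default for every entry in `outputs`, and the session reset now also reloads the instruction image and forces `task_type` to "task 4". The pattern in miniature:

    import gradio as gr

    with gr.Blocks() as demo:
        img = gr.Image()
        chat = gr.Chatbot()
        notes = gr.State([])
        session = gr.Radio(["Session 1", "Session 2"], label="session")

        def reset():
            # one default per output, in the same order as the outputs list
            return None, [], []

        session.change(reset, [], [img, chat, notes])

Because defaults are matched to outputs purely by position, the long tuples here (ten values each after this commit) are fragile: any output later added to these listeners must be mirrored by a value at the right position in every return statement.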