Anonymous Authors commited on
Commit
3dd4093
·
1 Parent(s): e089949

Update app.py

Browse files

adding bar plots, making image names more readable

Files changed (1) hide show
  1. app.py +25 -9
app.py CHANGED
@@ -2,6 +2,8 @@ import json
2
  import gradio as gr
3
  import os
4
  from PIL import Image
 
 
5
 
6
  clusters_12 = json.load(open("clusters/id_all_blip_clusters_12.json"))
7
  clusters_24 = json.load(open("clusters/id_all_blip_clusters_24.json"))
@@ -18,12 +20,25 @@ def show_cluster(cl_id, num_clusters):
18
  images = []
19
  for i in range(6):
20
  img_path = "/".join([st.replace("/", "") for st in cl_dct['img_path_list'][i].split("//")][3:])
21
- images.append((Image.open(os.path.join("identities-images", img_path)), "_".join([img_path.split("/")[0], img_path.split("/")[-1]])))
 
 
 
 
 
 
 
 
 
 
 
 
22
  return (len(cl_dct['img_path_list']),
23
- dict(cl_dct["labels_gender"]),
24
- dict(cl_dct["labels_model"]),
25
- dict(cl_dct["labels_ethnicity"]),
26
- images)
 
27
 
28
  with gr.Blocks() as demo:
29
  gr.Markdown("# Cluster Explorer")
@@ -34,11 +49,12 @@ with gr.Blocks() as demo:
34
  button = gr.Button(value="Go")
35
  with gr.Column():
36
  a = gr.Text(label="Number of items in cluster")
 
37
  with gr.Row():
38
- c = gr.Text(label="Model makeup of cluster")
39
- b = gr.Text(label="Gender label makeup of cluster")
40
- d = gr.Text(label="Ethnicity label makeup of cluster")
41
- gallery = gr.Gallery(label="Most representative images in cluster").style(grid=6)
42
  button.click(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=[a,b,c,d, gallery])
43
  # demo = gr.Interface(fn=show_cluster, inputs=[gr.Slider(0, 50), gr.Radio([12, 24, 48])], outputs=["text", "text", "text", "text", gr.Gallery()])
44
  demo.launch(debug=True)
 
2
  import gradio as gr
3
  import os
4
  from PIL import Image
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
 
8
  clusters_12 = json.load(open("clusters/id_all_blip_clusters_12.json"))
9
  clusters_24 = json.load(open("clusters/id_all_blip_clusters_24.json"))
 
20
  images = []
21
  for i in range(6):
22
  img_path = "/".join([st.replace("/", "") for st in cl_dct['img_path_list'][i].split("//")][3:])
23
+ images.append((Image.open(os.path.join("identities-images", img_path)), "_".join([img_path.split("/")[0], img_path.split("/")[-1]]).replace('Photo_portrait_of_an_','').replace('Photo_portrait_of_a_','').replace('SD_v2_random_seeds_identity_','(SD v.2) ').replace('dataset-identities-dalle2_','(Dall-E 2) ').replace('SD_v1.4_random_seeds_identity_','(SD v.1.4) ').replace('_',' ')))
24
+ model_fig = go.Figure()
25
+ model_fig.add_trace(go.Bar(x=list(dict(cl_dct["labels_model"]).keys()),
26
+ y=list(dict(cl_dct["labels_model"]).values()),
27
+ marker_color=px.colors.qualitative.G10))
28
+ gender_fig = go.Figure()
29
+ gender_fig.add_trace(go.Bar(x=list(dict(cl_dct["labels_gender"]).keys()),
30
+ y=list(dict(cl_dct["labels_gender"]).values()),
31
+ marker_color=px.colors.qualitative.G10))
32
+ ethnicity_fig = go.Figure()
33
+ ethnicity_fig.add_trace(go.Bar(x=list(dict(cl_dct["labels_ethnicity"]).keys()),
34
+ y=list(dict(cl_dct["labels_ethnicity"]).values()),
35
+ marker_color=px.colors.qualitative.G10))
36
  return (len(cl_dct['img_path_list']),
37
+ gender_fig,
38
+ #dict(cl_dct["labels_model"]),
39
+ model_fig,
40
+ ethnicity_fig,
41
+ images)
42
 
43
  with gr.Blocks() as demo:
44
  gr.Markdown("# Cluster Explorer")
 
49
  button = gr.Button(value="Go")
50
  with gr.Column():
51
  a = gr.Text(label="Number of items in cluster")
52
+ gallery = gr.Gallery(label="Most representative images in cluster").style(grid=6)
53
  with gr.Row():
54
+ c = gr.Plot(label="Model makeup of cluster")
55
+ b = gr.Plot(label="Gender label makeup of cluster")
56
+ d = gr.Plot(label="Ethnicity label makeup of cluster")
57
+
58
  button.click(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=[a,b,c,d, gallery])
59
  # demo = gr.Interface(fn=show_cluster, inputs=[gr.Slider(0, 50), gr.Radio([12, 24, 48])], outputs=["text", "text", "text", "text", gr.Gallery()])
60
  demo.launch(debug=True)