# File size: 6,238 bytes (page-scrape header; the per-line commit-hash and line-number gutter rows were dropped).
# import gradio as gr
# import time
# def update_options(selected_option):
# new_options = [selected_option + str(x) for x in range(1, 4)]
# return gr.Dropdown(choices=new_options, interactive=True)
# with gr.Blocks() as interface:
# dropdown = gr.Dropdown(choices=["Anna", "Christine", "Phoebe"])
# dropdown2 = gr.Dropdown(choices=["Anna", "Christine", "Phoebe"], allow_custom_value=True)
# dropdown.change(update_options, dropdown, dropdown2)
# interface.launch()
import gradio as gr
import plotly.graph_objects as go
import os
from collections import defaultdict
# Root folder of the prototype heatmap images, keyed by species.
# Every species shown in the app needs an entry here.
species_to_imgpath = {'bird': './descendent_specific_topk=10_heatmap_withbb_ep=last_024+051'}
# Image file name (extension included, e.g. 'node-3.png') -> path on disk.
imgname_to_filepath = {}
# Node name -> prototype IDs (strings) that have an image for that node.
# default_factory=list is required: a bare defaultdict() raises KeyError on append.
nodename_to_protoIDs = defaultdict(list)

# File suffixes treated as prototype images.
IMAGE_EXTENSIONS = ('.png', '.jpg')


def index_prototype_images(imgpath, name_to_path, node_to_protos):
    """Record every prototype image found one folder level below *imgpath*.

    Images are expected at ``<imgpath>/<folder>/<node>-<protoID>[...].<ext>``;
    plain files directly inside *imgpath* are ignored. Folders and files are
    visited in sorted order so the resulting prototype lists are deterministic.

    Args:
        imgpath: Root directory to scan. A missing directory is reported with a
            warning and skipped instead of aborting app start-up.
        name_to_path: dict updated in place, image file name -> file path.
        node_to_protos: dict updated in place, node name -> list of prototype
            IDs in discovery order (works with a plain dict or a defaultdict).
    """
    import warnings

    if not os.path.isdir(imgpath):
        warnings.warn(f"image root {imgpath!r} not found; no prototypes indexed")
        return
    for foldername in sorted(os.listdir(imgpath)):
        folderpath = os.path.join(imgpath, foldername)
        if not os.path.isdir(folderpath):
            continue
        for filename in sorted(os.listdir(folderpath)):
            if not filename.endswith(IMAGE_EXTENSIONS):
                continue
            name_to_path[filename] = os.path.join(folderpath, filename)
            # Text before the first '.' is '<node>-<protoID>[-...]'.
            parts = filename.split('.')[0].split('-')
            if len(parts) < 2:
                warnings.warn(f"skipping {filename!r}: expected '<node>-<protoID>' name")
                continue
            node_to_protos.setdefault(parts[0], []).append(parts[1])


for imgpath in species_to_imgpath.values():
    index_prototype_images(imgpath, imgname_to_filepath, nodename_to_protoIDs)
def display_tree():
    """Return a Plotly figure drawing a small, hard-coded placeholder tree.

    Four labelled nodes at fixed coordinates are joined by straight black
    edges. Grid, zero lines, tick labels, legend and margins are all hidden.
    The structure is a stand-in until real tree data is plotted.
    """
    labels = ['Node 1', 'Node 2', 'Node 3', 'Node 4']
    coords = [(0, 0), (1, 2), (1, -2), (2, 0)]
    # (parent, child) index pairs into `labels` / `coords`.
    links = [(0, 1), (0, 2), (2, 3)]

    # One segment per link; the trailing None breaks the line between segments.
    line_xs, line_ys = [], []
    for parent, child in links:
        (px, py), (qx, qy) = coords[parent], coords[child]
        line_xs += [px, qx, None]
        line_ys += [py, qy, None]

    links_trace = go.Scatter(
        x=line_xs,
        y=line_ys,
        mode='lines',
        hoverinfo='none',
        line=dict(width=2, color='Black'),
    )
    nodes_trace = go.Scatter(
        x=[x for x, _ in coords],
        y=[y for _, y in coords],
        mode='markers+text',
        text=labels,
        textposition="top center",
        hoverinfo='text',
        marker=dict(showscale=False, size=10, color='Goldenrod'),
    )

    def hidden_axis():
        # Fresh dict per axis so x and y never share one settings object.
        return dict(showgrid=False, zeroline=False, showticklabels=False)

    layout = go.Layout(
        showlegend=False,
        hovermode='closest',
        margin=dict(b=0, l=0, r=0, t=0),
        xaxis=hidden_axis(),
        yaxis=hidden_axis(),
    )
    # Edges are listed first so the node markers are drawn over them.
    return go.Figure(data=[links_trace, nodes_trace], layout=layout)
def _render_placeholder_image(text="Placeholder", size=(200, 100)):
    """Draw *text* centred in black on a white RGB image and save it as PNG.

    Args:
        text: Text to draw.
        size: (width, height) of the image in pixels.

    Returns:
        str: Path of the written file, ``dummy_image.png`` in the system temp
        directory. The file is overwritten on every call.
    """
    # Pillow is not among this file's top-level imports, so bring it in here.
    from PIL import Image, ImageDraw, ImageFont
    import tempfile

    img = Image.new('RGB', size, color='white')
    draw = ImageDraw.Draw(img)
    # A .ttf font could be used instead: ImageFont.truetype('/path/to/font.ttf', 40)
    font = ImageFont.load_default()
    # ImageDraw.textsize was removed in Pillow 10; textbbox gives the text extent.
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    width, height = img.size
    # Subtract the bbox origin so glyph side-bearings don't skew the centring.
    x = (width - (right - left)) / 2 - left
    y = (height - (bottom - top)) / 2 - top
    draw.text((x, y), text, font=font, fill='black')
    # gettempdir() instead of a hard-coded '/tmp' keeps this portable.
    out_path = os.path.join(tempfile.gettempdir(), 'dummy_image.png')
    img.save(out_path)
    return out_path


def display_image_based_on_dropdown_1(dropdown_value):
    """Return the path of a placeholder image for the first image panel.

    *dropdown_value* matches a Gradio event-handler signature but is not used yet.
    """
    return _render_placeholder_image()


def display_image_based_on_dropdown_2(dropdown_value):
    """Return the path of a placeholder image for the second image panel.

    *dropdown_value* matches a Gradio event-handler signature but is not used yet.
    """
    return _render_placeholder_image()
def get_protoIDs(nodename):
    """Build the prototype-ID dropdown for the node picked in a node-name dropdown.

    Args:
        nodename: Selected node name; ``None`` when the selection is cleared.

    Returns:
        gr.Dropdown: An interactive dropdown whose choices are that node's
        prototype IDs, or an empty list for ``None`` or an unknown node.
    """
    # .get avoids a KeyError (or, on a defaultdict, inserting an empty entry)
    # when the selection is cleared or names a node with no indexed images.
    return gr.Dropdown(choices=nodename_to_protoIDs.get(nodename, []), interactive=True)
with gr.Blocks() as demo:
    gr.Markdown("## Interactive Tree and Image Display")
    with gr.Row():
        # A callable value is evaluated by Gradio each time the app loads.
        tree_output = gr.Plot(display_tree)
    # Node names offered in both panels, sorted so the order is stable.
    node_choices = sorted(nodename_to_protoIDs)
    with gr.Row():
        with gr.Column():
            dropdown_1_nodename = gr.Dropdown(label="Select a node name", choices=node_choices)
            dropdown_1_protos = gr.Dropdown(label="Select a prototype ID", choices=[], allow_custom_value=True)
            # NOTE(review): assumes this teaser image sits next to the app — confirm.
            image_output_1 = gr.Image('new_teaser (3)-1.png')
        with gr.Column():
            dropdown_2_nodename = gr.Dropdown(label="Select a node name", choices=node_choices)
            dropdown_2_protos = gr.Dropdown(label="Select a prototype ID", choices=[], allow_custom_value=True)
            image_output_2 = gr.Image('new_teaser (3)-1.png')
    # Refill each panel's prototype dropdown whenever its node selection changes.
    dropdown_1_nodename.change(get_protoIDs, dropdown_1_nodename, dropdown_1_protos)
    dropdown_2_nodename.change(get_protoIDs, dropdown_2_nodename, dropdown_2_protos)
    # TODO: wire the prototype dropdowns to the image panels once image lookup exists.

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()
|