HComP-Net / app.py
harishB97's picture
Update app.py
d50a100 verified
raw
history blame
6.24 kB
# import gradio as gr
# import time
# def update_options(selected_option):
# new_options = [selected_option + str(x) for x in range(1, 4)]
# return gr.Dropdown(choices=new_options, interactive=True)
# with gr.Blocks() as interface:
# dropdown = gr.Dropdown(choices=["Anna", "Christine", "Phoebe"])
# dropdown2 = gr.Dropdown(choices=["Anna", "Christine", "Phoebe"], allow_custom_value=True)
# dropdown.change(update_options, dropdown, dropdown2)
# interface.launch()
import os
import tempfile
from collections import defaultdict

import gradio as gr
import plotly.graph_objects as go
from PIL import Image, ImageDraw, ImageFont
# Root folder of the prototype heatmap images for each species.
# Every species listed here needs its own folder of per-node subfolders.
species_to_imgpath = {'bird': './descendent_specific_topk=10_heatmap_withbb_ep=last_024+051'}
# File extensions (compared case-insensitively) treated as prototype images.
IMAGE_EXTENSIONS = ('.png', '.jpg')

# Full image filename, extension included (e.g. 'root-3.png') -> path on disk.
imgname_to_filepath = {}
# Node name -> prototype IDs found for that node, in sorted filename order.
nodename_to_protoIDs = defaultdict(list)


def index_prototype_images(root, filepaths, protos_by_node):
    """Index the prototype images stored under *root*.

    Images are expected one level down, inside subfolders of *root*, and
    named ``<nodename>-<protoID>.<ext>``. Each matching image is recorded in
    *filepaths* (full filename -> path) and its protoID is appended to
    ``protos_by_node[nodename]``. Files placed directly in *root*, non-image
    files and images whose name lacks a ``-`` separator are skipped.

    Args:
        root: directory to scan. A missing directory is reported and skipped
            instead of aborting the whole app.
        filepaths: dict updated in place.
        protos_by_node: dict (or defaultdict) of node name -> list of
            protoIDs, updated in place.
    """
    if not os.path.isdir(root):
        print(f"Warning: prototype image folder not found: {root}")
        return
    # Sorted so that dropdown choices come out in a stable order.
    for foldername in sorted(os.listdir(root)):
        folderpath = os.path.join(root, foldername)
        if not os.path.isdir(folderpath):
            continue
        for filename in sorted(os.listdir(folderpath)):
            if not filename.lower().endswith(IMAGE_EXTENSIONS):
                continue
            # Everything before the first '.' is '<nodename>-<protoID>'.
            parts = filename.split('.')[0].split('-')
            if len(parts) < 2:
                print(f"Warning: skipping image not named '<node>-<protoID>': {filename}")
                continue
            nodename, protoID = parts[0], parts[1]
            filepaths[filename] = os.path.join(folderpath, filename)
            protos_by_node.setdefault(nodename, []).append(protoID)


for imgpath in species_to_imgpath.values():
    index_prototype_images(imgpath, imgname_to_filepath, nodename_to_protoIDs)
def display_tree():
    """Return a Plotly figure that draws a small, hard-coded placeholder tree.

    The node labels, edge list and node coordinates below are stand-ins; swap
    them for the real hierarchy once it is available.

    Returns:
        A plotly.graph_objects.Figure holding one line trace (the edges) and
        one marker+text trace (the labelled nodes) on a bare canvas.
    """
    labels = ['Node 1', 'Node 2', 'Node 3', 'Node 4']
    # (start, end) pairs, given as indices into `labels` / `coords`.
    links = [(0, 1), (0, 2), (2, 3)]
    # Manually chosen (x, y) position for each label; a layout algorithm
    # would be needed for larger graphs.
    coords = [(0, 0), (1, 2), (1, -2), (2, 0)]

    # All edges go into a single line trace; a None between two segments
    # tells Plotly to break the line so the segments stay disconnected.
    seg_x, seg_y = [], []
    for start, end in links:
        (xa, ya), (xb, yb) = coords[start], coords[end]
        seg_x += [xa, xb, None]
        seg_y += [ya, yb, None]
    lines = go.Scatter(
        x=seg_x,
        y=seg_y,
        mode='lines',
        hoverinfo='none',
        line=dict(width=2, color='Black'),
    )

    # Nodes are drawn as markers with their label written above them.
    points = go.Scatter(
        x=[xy[0] for xy in coords],
        y=[xy[1] for xy in coords],
        mode='markers+text',
        text=labels,
        textposition="top center",
        hoverinfo='text',
        marker=dict(showscale=False, size=10, color='Goldenrod'),
    )

    # Bare canvas: no legend, no margins, no grid, zero lines or tick labels.
    bare_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    layout = go.Layout(
        showlegend=False,
        hovermode='closest',
        margin=dict(b=0, l=0, r=0, t=0),
        xaxis=dict(bare_axis),
        yaxis=dict(bare_axis),
    )
    return go.Figure(data=[lines, points], layout=layout)
def _render_placeholder_image():
    """Draw 'Placeholder' centred on a 200x100 white image and save it as PNG.

    Returns:
        Path of the written file: ``dummy_image.png`` in the system temporary
        directory. The same path is overwritten on every call.
    """
    img = Image.new('RGB', (200, 100), color='white')
    draw = ImageDraw.Draw(img)
    # A .ttf font could be used instead, e.g.
    # ImageFont.truetype('/path/to/font.ttf', 40); the built-in default font
    # avoids having to ship a font file.
    fnt = ImageFont.load_default()
    text = "Placeholder"
    # textbbox replaces ImageDraw.textsize, which was removed in Pillow 10.
    left, top, right, bottom = draw.textbbox((0, 0), text, font=fnt)
    # Subtract the bbox offset so the visible glyphs, not the origin, are centred.
    x = (img.width - (right - left)) / 2 - left
    y = (img.height - (bottom - top)) / 2 - top
    draw.text((x, y), text, font=fnt, fill='black')
    # gettempdir() instead of a hard-coded '/tmp' so this also works on
    # platforms that have no /tmp directory.
    out_path = os.path.join(tempfile.gettempdir(), 'dummy_image.png')
    img.save(out_path)
    return out_path


def display_image_based_on_dropdown_1(dropdown_value):
    """Return the path of a placeholder image for the first image panel.

    Args:
        dropdown_value: the current dropdown selection; not used yet.
    """
    return _render_placeholder_image()


def display_image_based_on_dropdown_2(dropdown_value):
    """Return the path of a placeholder image for the second image panel.

    Args:
        dropdown_value: the current dropdown selection; not used yet.
    """
    return _render_placeholder_image()
def get_protoIDs(nodename):
    """Build the prototype-ID dropdown update for the selected node.

    Args:
        nodename: value of the node-name dropdown (presumably None when the
            selection is cleared).

    Returns:
        A gr.Dropdown offering the prototype IDs recorded for *nodename*
        (no choices when the node is unknown), with any previous selection
        cleared.
    """
    # .get avoids a KeyError and, for a defaultdict, avoids inserting an empty
    # entry for every unknown or None node name.
    choices = nodename_to_protoIDs.get(nodename, [])
    # value=None drops a protoID that was picked for the previously selected node.
    return gr.Dropdown(choices=choices, value=None, interactive=True)
def _node_selector_column():
    """Create one column with node/prototype dropdowns and an image panel.

    Must be called inside an open gr.Blocks layout. Choosing a node name
    refreshes the prototype-ID dropdown through get_protoIDs.

    Returns:
        (node-name dropdown, prototype-ID dropdown, image component).
    """
    with gr.Column():
        nodename_dd = gr.Dropdown(label="Select a node name", choices=sorted(nodename_to_protoIDs))
        protos_dd = gr.Dropdown(label="Select a prototype ID", choices=[], allow_custom_value=True)
        image = gr.Image('new_teaser (3)-1.png')
    nodename_dd.change(get_protoIDs, nodename_dd, protos_dd)
    return nodename_dd, protos_dd, image


with gr.Blocks() as demo:
    gr.Markdown("## Interactive Tree and Image Display")
    with gr.Row():
        # Passing the function (not its result) makes Gradio call it on page load.
        tree_output = gr.Plot(display_tree)
    with gr.Row():
        # Two side-by-side, identical selector columns for comparing prototypes.
        dropdown_1_nodename, dropdown_1_protos, image_output_1 = _node_selector_column()
        dropdown_2_nodename, dropdown_2_protos, image_output_2 = _node_selector_column()

# Launch only when run as a script so the module can be imported (e.g. by
# `gradio app.py` reload mode) without starting a server.
if __name__ == "__main__":
    demo.launch()