File size: 4,399 Bytes
525a6df
 
 
 
 
9a99cab
129ce22
 
525a6df
 
 
01c40ff
525a6df
 
 
8254f87
525a6df
 
 
 
 
d50a100
 
525a6df
 
f1853ff
525a6df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4541211
 
 
 
 
 
 
525a6df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4541211
525a6df
51be45a
 
525a6df
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import plotly.graph_objects as go
import os
from collections import defaultdict

# print(os.getcwd())  # debug: show the working directory the relative image paths resolve against

# Root image directory per species; each root holds one sub-folder of images.
species_to_imgpath = {'bird': './descendent_specific_topk=10_heatmap_withbb_ep=last_024+051'}

# Only files with exactly these extensions are indexed.
IMAGE_EXTENSIONS = ('.png', '.jpg')


def build_image_index(species_to_dir):
    """Scan the image folders and build the lookup tables used by the UI.

    Every ``<root>/<folder>/<nodename>-<protoID>.<ext>`` file is indexed.
    The name is split on the FIRST hyphen, so ``'-'.join([nodename, protoID])``
    always rebuilds the original file stem.

    Args:
        species_to_dir: Mapping of species name to the root directory that
            contains the per-folder images.

    Returns:
        A ``(name_to_path, node_to_protos)`` tuple:
        * ``name_to_path`` maps the full file name (extension included) to
          its path on disk;
        * ``node_to_protos`` is a ``defaultdict(list)`` mapping a node name
          to its prototype IDs (strings), without duplicates.
    """
    import warnings  # local import: only needed on the missing-directory path

    name_to_path = {}
    node_to_protos = defaultdict(list)

    for species, imgpath in species_to_dir.items():
        if not os.path.isdir(imgpath):
            # Keep the app usable (with empty dropdowns) instead of crashing
            # at import time, but make the misconfiguration visible.
            warnings.warn(
                f"image directory for species {species!r} not found: {imgpath!r}"
            )
            continue

        # sorted(): os.listdir order is arbitrary; sorting keeps the dropdown
        # contents stable between runs.
        for foldername in sorted(os.listdir(imgpath)):
            folderpath = os.path.join(imgpath, foldername)
            if not os.path.isdir(folderpath):
                continue

            for filename in sorted(os.listdir(folderpath)):
                stem, ext = os.path.splitext(filename)
                if ext not in IMAGE_EXTENSIONS:
                    continue

                nodename, sep, protoID = stem.partition('-')
                if not (sep and nodename and protoID):
                    # Not a "<nodename>-<protoID>" file; nothing to index.
                    continue

                name_to_path[filename] = os.path.join(folderpath, filename)
                if protoID not in node_to_protos[nodename]:
                    node_to_protos[nodename].append(protoID)

    return name_to_path, node_to_protos


# NOTE: both tables are shared by all species; the original author's note says
# this "has to be there for each species", i.e. per-species tables are still a TODO.
# imgname_to_filepath keys are full file names, extension INCLUDED.
imgname_to_filepath, nodename_to_protoIDs = build_image_index(species_to_imgpath)
                    

def display_tree():
    """Build a Plotly figure of a small, hard-coded placeholder tree.

    The four nodes, their coordinates and the three edges are fixed in this
    function; nothing is read from the image folders.

    Returns:
        A ``go.Figure`` with one line trace (edges) and one marker+text
        trace (nodes), with axes, grid and legend hidden.
    """
    labels = ['Node 1', 'Node 2', 'Node 3', 'Node 4']
    # One (x, y) coordinate per label, in the same order.
    coords = [(0, 0), (1, 2), (1, -2), (2, 0)]
    # Each link is a (from_index, to_index) pair into `labels` / `coords`.
    links = [(0, 1), (0, 2), (2, 3)]

    # All edges live in a single trace; the None after each pair of points
    # ends the segment so consecutive edges are not joined together.
    line_xs, line_ys = [], []
    for src, dst in links:
        (src_x, src_y), (dst_x, dst_y) = coords[src], coords[dst]
        line_xs += [src_x, dst_x, None]
        line_ys += [src_y, dst_y, None]

    link_trace = go.Scatter(
        x=line_xs,
        y=line_ys,
        mode='lines',
        hoverinfo='none',
        line=dict(width=2, color='Black'),
    )

    marker_xs, marker_ys = (list(axis) for axis in zip(*coords))
    marker_trace = go.Scatter(
        x=marker_xs,
        y=marker_ys,
        text=labels,
        textposition="top center",
        mode='markers+text',
        hoverinfo='text',
        marker=dict(showscale=False, size=10, color='Goldenrod'),
    )

    # Both axes are purely positional, so hide grid, zero line and ticks.
    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    figure_layout = go.Layout(
        showlegend=False,
        hovermode='closest',
        margin=dict(b=0, l=0, r=0, t=0),
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
    )

    return go.Figure(data=[link_trace, marker_trace], layout=figure_layout)



def get_protoIDs(nodename):
    """Return a dropdown update listing the prototype IDs of a node.

    Args:
        nodename: Value of the node-name dropdown. Unknown or empty values
            (e.g. no selection) yield an empty choice list.

    Returns:
        A ``gr.Dropdown`` carrying the new ``choices`` and
        ``interactive=True``, to be applied to the prototype dropdown.
    """
    # .get() instead of indexing: nodename_to_protoIDs is a defaultdict, so
    # `nodename_to_protoIDs[nodename]` would insert an empty entry for every
    # unknown/None name it is asked about.
    # list(): hand Gradio a copy rather than the shared index list.
    choices = list(nodename_to_protoIDs.get(nodename, []))
    return gr.Dropdown(choices=choices, interactive=True)


def get_image(nodename, protoID):
    """Return an image update for the ``<nodename>-<protoID>`` image.

    Args:
        nodename: Value of the node-name dropdown (may be empty).
        protoID: Value of the prototype dropdown. That dropdown allows custom
            values, so this can be text that matches no indexed image.

    Returns:
        A ``gr.Image`` showing the matching file, or an empty ``gr.Image``
        when the selection is incomplete or no such image was indexed.
    """
    if not nodename or not protoID:
        # Incomplete selection: nothing to look up, show no image.
        return gr.Image(value=None)

    stem = f"{nodename}-{protoID}"
    # The index is keyed by full file name and holds both .png and .jpg
    # files, so try each extension instead of assuming .png.
    for ext in ('.png', '.jpg'):
        filepath = imgname_to_filepath.get(stem + ext)
        if filepath is not None:
            return gr.Image(filepath)

    # Unknown combination (e.g. a typed custom value): clear the image
    # rather than failing the event with a KeyError.
    return gr.Image(value=None)
    

# UI layout: the tree plot on top, then two identical side-by-side pickers
# (node name -> prototype ID -> image) so two prototypes can be compared.
with gr.Blocks() as demo:
    gr.Markdown("## Interactive Tree and Image Display")

    with gr.Row():
        # The function itself is passed as the value, so Gradio calls it to
        # produce the figure.
        tree_output = gr.Plot(value=display_tree)

    with gr.Row():
        with gr.Column():
            dropdown_1_nodename = gr.Dropdown(
                label="Select a node name",
                choices=list(nodename_to_protoIDs),
            )
            # Starts empty; filled by get_protoIDs once a node is chosen.
            dropdown_1_protos = gr.Dropdown(
                label="Select a prototype ID",
                choices=[],
                allow_custom_value=True,
            )
            # Placeholder shown until a prototype is selected.
            image_output_1 = gr.Image(value='new_teaser (3)-1.png')
        with gr.Column():
            dropdown_2_nodename = gr.Dropdown(
                label="Select a node name",
                choices=list(nodename_to_protoIDs),
            )
            dropdown_2_protos = gr.Dropdown(
                label="Select a prototype ID",
                choices=[],
                allow_custom_value=True,
            )
            image_output_2 = gr.Image(value='new_teaser (3)-1.png')

        # Left picker: node change refreshes the prototype list, prototype
        # change loads the image.
        dropdown_1_nodename.change(
            fn=get_protoIDs,
            inputs=dropdown_1_nodename,
            outputs=dropdown_1_protos,
        )
        dropdown_1_protos.change(
            fn=get_image,
            inputs=[dropdown_1_nodename, dropdown_1_protos],
            outputs=image_output_1,
        )
        # Right picker: same wiring on the second set of components.
        dropdown_2_nodename.change(
            fn=get_protoIDs,
            inputs=dropdown_2_nodename,
            outputs=dropdown_2_protos,
        )
        dropdown_2_protos.change(
            fn=get_image,
            inputs=[dropdown_2_nodename, dropdown_2_protos],
            outputs=image_output_2,
        )


# Start the Gradio server as soon as this file is executed.
demo.launch()