harishB97 commited on
Commit
b63e42a
·
verified ·
1 Parent(s): 08b8cdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -53
app.py CHANGED
@@ -128,71 +128,154 @@ def get_tree(imgpath):
128
  ROOT = None
129
 
130
 
131
- def display_tree():
132
- nodes = []
133
- edges = []
134
- positions = {}
135
 
136
- root = ROOT
137
 
138
- def traverse(node, depth=0, index=0):
139
 
140
- if depth >= 3:
141
- return
142
 
143
- if node not in nodes:
144
- nodes.append(node)
145
- idx = nodes.index(node)
146
- positions[idx] = (depth * 1, index * 1 - len(nodes) / 2) # Adjusted the multipliers for depth and index
147
-
148
- for child in node.children:
149
- if child not in nodes:
150
- nodes.append(child)
151
- child_idx = nodes.index(child)
152
- edges.append((idx, child_idx))
153
- traverse(child, depth + 1, index + len(node.children) / 2) # Recursively traverse to set positions
154
-
155
- traverse(root)
156
-
157
- edge_x = []
158
- edge_y = []
159
- for edge in edges:
160
- x0, y0 = positions[edge[0]]
161
- x1, y1 = positions[edge[1]]
162
- edge_x.extend([x0, x1, None])
163
- edge_y.extend([y0, y1, None])
164
 
165
- edge_trace = go.Scatter(
166
- x=edge_x, y=edge_y,
167
- line=dict(width=2, color='Black'),
168
- hoverinfo='none',
169
- mode='lines')
 
 
170
 
171
- node_x = [pos[0] for pos in positions.values()]
172
- node_y = [pos[1] for pos in positions.values()]
 
 
 
173
 
174
- node_trace = go.Scatter(
175
- x=node_x, y=node_y,
176
- mode='markers+text',
177
- hoverinfo='text',
178
- marker=dict(showscale=False, size=10, color='Goldenrod'),
179
- text=[node.name for node in nodes],
180
- textposition="top center"
181
- )
 
 
 
182
 
183
- layout = go.Layout(
184
- title="Tree Visualization",
185
- showlegend=False,
186
- hovermode='closest',
187
- margin=dict(b=0, l=0, r=0, t=40),
188
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
189
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  )
191
 
192
- fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
193
- return fig
 
 
 
194
 
 
 
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
 
198
  def get_protoIDs(nodename):
 
128
  ROOT = None
129
 
130
 
131
+ # def display_tree():
132
+ # nodes = []
133
+ # edges = []
134
+ # positions = {}
135
 
136
+ # root = ROOT
137
 
138
+ # def traverse(node, depth=0, index=0):
139
 
140
+ # if depth >= 3:
141
+ # return
142
 
143
+ # if node not in nodes:
144
+ # nodes.append(node)
145
+ # idx = nodes.index(node)
146
+ # positions[idx] = (depth * 1, index * 1 - len(nodes) / 2) # Adjusted the multipliers for depth and index
147
+
148
+ # for child in node.children:
149
+ # if child not in nodes:
150
+ # nodes.append(child)
151
+ # child_idx = nodes.index(child)
152
+ # edges.append((idx, child_idx))
153
+ # traverse(child, depth + 1, index + len(node.children) / 2) # Recursively traverse to set positions
154
+
155
+ # traverse(root)
 
 
 
 
 
 
 
 
156
 
157
+ # edge_x = []
158
+ # edge_y = []
159
+ # for edge in edges:
160
+ # x0, y0 = positions[edge[0]]
161
+ # x1, y1 = positions[edge[1]]
162
+ # edge_x.extend([x0, x1, None])
163
+ # edge_y.extend([y0, y1, None])
164
 
165
+ # edge_trace = go.Scatter(
166
+ # x=edge_x, y=edge_y,
167
+ # line=dict(width=2, color='Black'),
168
+ # hoverinfo='none',
169
+ # mode='lines')
170
 
171
+ # node_x = [pos[0] for pos in positions.values()]
172
+ # node_y = [pos[1] for pos in positions.values()]
173
+
174
+ # node_trace = go.Scatter(
175
+ # x=node_x, y=node_y,
176
+ # mode='markers+text',
177
+ # hoverinfo='text',
178
+ # marker=dict(showscale=False, size=10, color='Goldenrod'),
179
+ # text=[node.name for node in nodes],
180
+ # textposition="top center"
181
+ # )
182
 
183
+ # layout = go.Layout(
184
+ # title="Tree Visualization",
185
+ # showlegend=False,
186
+ # hovermode='closest',
187
+ # margin=dict(b=0, l=0, r=0, t=40),
188
+ # xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
189
+ # yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
190
+ # )
191
+
192
+ # fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
193
+ # return fig
194
+
195
+
196
+
197
+
198
+
199
+ def create_edge_list(node, edge_list=None, name_map=None):
200
+ if edge_list is None:
201
+ edge_list = []
202
+ if name_map is None:
203
+ name_map = {}
204
+
205
+ if node.name not in name_map:
206
+ name_map[node.name] = len(name_map)
207
+
208
+ for child in node.children:
209
+ if child.name not in name_map:
210
+ name_map[child.name] = len(name_map)
211
+ edge_list.append((name_map[node.name], name_map[child.name]))
212
+ create_edge_list(child, edge_list, name_map)
213
+
214
+ return edge_list, list(name_map.keys())
215
+
216
+ def display_tree():
217
+ root = ROOT
218
+ edge_list, node_names = create_edge_list(root)
219
+
220
+ # Create an igraph Graph from edge list
221
+ g = ig.Graph(edges=edge_list, directed=True)
222
+
223
+ # Use the Reingold-Tilford tree layout
224
+ layout = g.layout('rt', root=[0])
225
+
226
+ # Scale the layout to make the tree more compact
227
+ scale_factor = 0.2 # Adjust this factor as needed
228
+ layout_coords = [(coord[0] * scale_factor, coord[1] * scale_factor) for coord in layout.coords]
229
+
230
+ x_coords = [coord[0] for coord in layout_coords]
231
+ y_coords = [-coord[1] for coord in layout_coords] # invert y-axis for a top-down tree view
232
+
233
+ # Create Plotly traces for nodes and edges
234
+ edge_trace = go.Scatter(
235
+ x=[None],
236
+ y=[None],
237
+ line=dict(width=2, color='#888'),
238
+ hoverinfo='none',
239
+ mode='lines'
240
  )
241
 
242
+ for v1, v2 in g.get_edgelist():
243
+ x0, y0 = layout_coords[v1]
244
+ x1, y1 = layout_coords[v2]
245
+ # edge_trace['x'] += [x0, x1, None]
246
+ # edge_trace['y'] += [-y0, -y1, None] # invert y-axis
247
 
248
+ edge_trace['x'] += tuple(list(edge_trace['x']) + [x0, x1, None])
249
+ edge_trace['y'] += tuple(list(edge_trace['y']) + [-y0, -y1, None]) # invert y-axis
250
 
251
+ node_trace = go.Scatter(
252
+ x=x_coords, y=y_coords,
253
+ text=node_names,
254
+ mode='markers+text',
255
+ hoverinfo='text',
256
+ textposition='top center',
257
+ marker=dict(
258
+ showscale=False,
259
+ color='Blue',
260
+ size=10,
261
+ line_width=2),
262
+ textfont=dict(
263
+ size=12, # Increase the font size as needed
264
+ color='Black'
265
+ )
266
+ )
267
+
268
+ # Create a Plotly figure
269
+ fig = go.Figure(data=[edge_trace, node_trace],
270
+ layout=go.Layout(
271
+ showlegend=False,
272
+ hovermode='closest',
273
+ margin=dict(b=0, l=0, r=0, t=0),
274
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
275
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
276
+ title="Tree Visualization"
277
+ ))
278
+ return fig
279
 
280
 
281
  def get_protoIDs(nodename):