LeonceNsh commited on
Commit
77dc67f
·
verified ·
1 Parent(s): 01ca6ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -74
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import pandas as pd
2
  import networkx as nx
3
- import matplotlib.pyplot as plt
4
  from io import BytesIO
5
  from PIL import Image
6
  import gradio as gr
@@ -108,12 +108,12 @@ def filter_investors_by_country_and_industry(selected_country, selected_industry
108
 
109
  return list(investor_valuations.keys()), filtered_data
110
 
111
- # Function to generate the graph and return node information
112
- def generate_graph_and_get_node_info(selected_investors, filtered_data, clicked_node=None):
113
  if not selected_investors:
114
- logger.warning("No investors selected. Returning None for graph.")
115
- return None, "No investors selected."
116
-
117
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
118
  filtered_mapping = {inv: investor_company_mapping_filtered[inv] for inv in selected_investors if inv in investor_company_mapping_filtered}
119
 
@@ -125,66 +125,94 @@ def generate_graph_and_get_node_info(selected_investors, filtered_data, clicked_
125
  for company in companies:
126
  G.add_edge(investor, company)
127
 
128
- # Node size based on valuation
129
- max_valuation = filtered_data["Valuation_Billions"].max()
130
- node_sizes = []
131
- for node in G.nodes:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if node in filtered_mapping:
133
- node_sizes.append(1500) # Fixed size for investors
 
 
 
134
  else:
135
  valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
136
- size = (valuation / max_valuation) * 1500 if max_valuation else 100
137
- node_sizes.append(size)
138
-
139
- # Node color: Investors (orange), Companies (green)
140
- node_colors = ["#FF8C00" if node in filtered_mapping else "#32CD32" for node in G.nodes]
141
 
142
- # Draw the graph
143
- plt.figure(figsize=(15, 15))
144
- pos = nx.spring_layout(G, k=0.2, seed=42)
145
- nx.draw(
146
- G, pos,
147
- with_labels=True,
148
- node_size=node_sizes,
149
- node_color=node_colors,
150
- font_size=10,
151
- edge_color="#A9A9A9", # Light gray edges
152
- alpha=0.9
 
 
153
  )
154
 
155
- # Legend
156
- from matplotlib.lines import Line2D
157
- legend_elements = [
158
- Line2D([0], [0], marker='o', color='w', label='Investor', markersize=10, markerfacecolor='#FF8C00'),
159
- Line2D([0], [0], marker='o', color='w', label='Company', markersize=10, markerfacecolor='#32CD32')
160
- ]
161
- plt.legend(handles=legend_elements, loc='upper left')
162
-
163
- plt.title("Venture Network Visualization", fontsize=20)
164
- plt.axis("off")
165
-
166
- # Save plot to BytesIO
167
- buf = BytesIO()
168
- plt.savefig(buf, format="png", bbox_inches="tight")
169
- plt.close()
170
- buf.seek(0)
171
-
172
- # Get node information if clicked
173
- node_info = "No node clicked."
174
- if clicked_node:
175
- if clicked_node in filtered_data["Company"].values:
176
- valuation = filtered_data.loc[filtered_data["Company"] == clicked_node, "Valuation_Billions"].iloc[0]
177
- node_info = f"Company: {clicked_node}, Valuation: ${valuation}B"
178
- elif clicked_node in filtered_mapping:
179
- node_info = f"Investor: {clicked_node}"
180
-
181
- return Image.open(buf), node_info
182
-
183
- # Gradio app function
184
- def app(selected_country, selected_industry, selected_investors, clicked_node=None):
185
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
186
- graph_image, node_info = generate_graph_and_get_node_info(selected_investors, filtered_data, clicked_node)
187
- return graph_image, node_info, gr.update(choices=investor_list, value=selected_investors)
 
 
 
 
 
 
188
 
189
  # Gradio Interface
190
  def main():
@@ -196,33 +224,57 @@ def main():
196
 
197
  with gr.Blocks() as demo:
198
  with gr.Row():
 
199
  country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="US")
200
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="Enterprise Tech")
201
 
202
- filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors")
203
- clicked_node_input = gr.Textbox(label="Clicked Node (Enter name to simulate a click)", placeholder="Enter a node name")
204
- graph_output = gr.Image(type="pil", label="Venture Network Graph")
205
- node_info_output = gr.Textbox(label="Node Information", interactive=False)
 
206
 
 
207
  country_filter.change(
208
  app,
209
- inputs=[country_filter, industry_filter, filtered_investor_list, clicked_node_input],
210
- outputs=[graph_output, node_info_output, filtered_investor_list]
211
  )
212
  industry_filter.change(
213
  app,
214
- inputs=[country_filter, industry_filter, filtered_investor_list, clicked_node_input],
215
- outputs=[graph_output, node_info_output, filtered_investor_list]
216
  )
 
 
217
  filtered_investor_list.change(
218
- app,
219
- inputs=[country_filter, industry_filter, filtered_investor_list, clicked_node_input],
220
- outputs=[graph_output, node_info_output, filtered_investor_list]
221
  )
222
- clicked_node_input.change(
223
- app,
224
- inputs=[country_filter, industry_filter, filtered_investor_list, clicked_node_input],
225
- outputs=[graph_output, node_info_output, filtered_investor_list]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  )
227
 
228
  demo.launch()
 
1
  import pandas as pd
2
  import networkx as nx
3
+ import plotly.graph_objects as go
4
  from io import BytesIO
5
  from PIL import Image
6
  import gradio as gr
 
108
 
109
  return list(investor_valuations.keys()), filtered_data
110
 
111
+ # Function to generate the Plotly graph
112
+ def generate_graph(selected_investors, filtered_data):
113
  if not selected_investors:
114
+ logger.warning("No investors selected. Returning empty figure.")
115
+ return go.Figure()
116
+
117
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
118
  filtered_mapping = {inv: investor_company_mapping_filtered[inv] for inv in selected_investors if inv in investor_company_mapping_filtered}
119
 
 
125
  for company in companies:
126
  G.add_edge(investor, company)
127
 
128
+ # Generate positions using spring layout
129
+ pos = nx.spring_layout(G, k=0.2, seed=42)
130
+
131
+ # Prepare Plotly traces
132
+ edge_x = []
133
+ edge_y = []
134
+ for edge in G.edges():
135
+ x0, y0 = pos[edge[0]]
136
+ x1, y1 = pos[edge[1]]
137
+ edge_x += [x0, x1, None]
138
+ edge_y += [y0, y1, None]
139
+
140
+ edge_trace = go.Scatter(
141
+ x=edge_x, y=edge_y,
142
+ line=dict(width=0.5, color='#888'),
143
+ hoverinfo='none',
144
+ mode='lines'
145
+ )
146
+
147
+ node_x = []
148
+ node_y = []
149
+ node_text = []
150
+ node_size = []
151
+ node_color = []
152
+ customdata = []
153
+ for node in G.nodes():
154
+ x, y = pos[node]
155
+ node_x.append(x)
156
+ node_y.append(y)
157
  if node in filtered_mapping:
158
+ node_text.append(f"Investor: {node}")
159
+ node_size.append(20) # Investors have larger size
160
+ node_color.append('orange')
161
+ customdata.append(None) # Investors do not have a single valuation
162
  else:
163
  valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
164
+ node_text.append(f"Company: {node}<br>Valuation: ${valuation}B")
165
+ node_size.append(10 + (valuation / filtered_data["Valuation_Billions"].max()) * 30 if filtered_data["Valuation_Billions"].max() else 10)
166
+ node_color.append('green')
167
+ customdata.append(f"${valuation}B")
 
168
 
169
+ node_trace = go.Scatter(
170
+ x=node_x, y=node_y,
171
+ mode='markers',
172
+ hoverinfo='text',
173
+ text=node_text,
174
+ customdata=customdata,
175
+ marker=dict(
176
+ showscale=False,
177
+ colorscale='YlGnBu',
178
+ color=node_color,
179
+ size=node_size,
180
+ line_width=2
181
+ )
182
  )
183
 
184
+ fig = go.Figure(data=[edge_trace, node_trace],
185
+ layout=go.Layout(
186
+ title='Venture Network Visualization',
187
+ titlefont_size=16,
188
+ showlegend=False,
189
+ hovermode='closest',
190
+ margin=dict(b=20,l=5,r=5,t=40),
191
+ annotations=[ dict(
192
+ text="",
193
+ showarrow=False,
194
+ xref="paper", yref="paper") ],
195
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
196
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
197
+ )
198
+
199
+ fig.update_traces(marker=dict(line=dict(width=0.5, color='white')), selector=dict(mode='markers'))
200
+
201
+ logger.info("Plotly graph generated successfully.")
202
+
203
+ return fig
204
+
205
+ # Gradio app function to update CheckboxGroup and filtered data
206
+ def app(selected_country, selected_industry):
 
 
 
 
 
 
 
207
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
208
+ logger.info("Updating CheckboxGroup and filtered data holder.")
209
+
210
+ # Use gr.update() to create an update dictionary for the CheckboxGroup
211
+ return gr.update(
212
+ choices=investor_list,
213
+ value=investor_list,
214
+ visible=True
215
+ ), filtered_data
216
 
217
  # Gradio Interface
218
  def main():
 
224
 
225
  with gr.Blocks() as demo:
226
  with gr.Row():
227
+ # Set default value to "US" for country and "Enterprise Tech" for industry
228
  country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="US")
229
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="Enterprise Tech")
230
 
231
+ filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
232
+ graph_output = gr.Plot(label="Venture Network Graph")
233
+ valuation_display = gr.Markdown(value="Click on a company node to see its valuation.", label="Company Valuation")
234
+
235
+ filtered_data_holder = gr.State()
236
 
237
+ # Event handlers for filters
238
  country_filter.change(
239
  app,
240
+ inputs=[country_filter, industry_filter],
241
+ outputs=[filtered_investor_list, filtered_data_holder]
242
  )
243
  industry_filter.change(
244
  app,
245
+ inputs=[country_filter, industry_filter],
246
+ outputs=[filtered_investor_list, filtered_data_holder]
247
  )
248
+
249
+ # Generate graph when investors are selected
250
  filtered_investor_list.change(
251
+ generate_graph,
252
+ inputs=[filtered_investor_list, filtered_data_holder],
253
+ outputs=graph_output
254
  )
255
+
256
+ # Handle plot click to display valuation
257
+ def display_valuation(plotly_click):
258
+ if plotly_click is None:
259
+ return "Click on a company node to see its valuation."
260
+ point = plotly_click
261
+ if 'text' in point and point['text']:
262
+ text = point['text']
263
+ if "Company:" in text:
264
+ # Extract valuation
265
+ parts = text.split("<br>")
266
+ company_part = parts[0]
267
+ valuation_part = parts[1]
268
+ company = company_part.replace("Company: ", "")
269
+ valuation = valuation_part.replace("Valuation: ", "")
270
+ return f"**{company}** has a valuation of **{valuation}**."
271
+ return "Click on a company node to see its valuation."
272
+
273
+ graph_output.event(
274
+ "plotly_click",
275
+ fn=display_valuation,
276
+ inputs=graph_output,
277
+ outputs=valuation_display
278
  )
279
 
280
  demo.launch()