LeonceNsh commited on
Commit
04d2663
·
verified ·
1 Parent(s): 4648e91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -238
app.py CHANGED
@@ -26,7 +26,7 @@ except Exception as e:
26
  data.columns = data.columns.str.strip().str.lower()
27
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
 
29
- # Filter out Health since Healthcare is the correct Market Segment
30
  data = data[data.industry != 'Health']
31
 
32
  # Identify the valuation column
@@ -39,8 +39,12 @@ valuation_column = valuation_columns[0]
39
  logger.info(f"Using valuation column: {valuation_column}")
40
 
41
  # Clean and prepare data
42
- data["Valuation_Billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
43
- data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
 
 
 
 
44
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
45
  data.rename(columns={
46
  "company": "Company",
@@ -69,9 +73,8 @@ def build_investor_company_mapping(df):
69
  investor_company_mapping = build_investor_company_mapping(data)
70
  logger.info("Investor to company mapping created.")
71
 
72
-
73
  # -------------------------
74
- # Valuation-Range Logic
75
  # -------------------------
76
  def filter_by_valuation_range(df, selected_valuation_range):
77
  """Filter dataframe by the specified valuation range in billions."""
@@ -79,19 +82,18 @@ def filter_by_valuation_range(df, selected_valuation_range):
79
  return df # No further filtering
80
 
81
  if selected_valuation_range == "1-5":
82
- return df[(df["Valuation_Billions"] >= 1) & (df["Valuation_Billions"] < 5)]
83
  elif selected_valuation_range == "5-10":
84
- return df[(df["Valuation_Billions"] >= 5) & (df["Valuation_Billions"] < 10)]
85
  elif selected_valuation_range == "10-15":
86
- return df[(df["Valuation_Billions"] >= 10) & (df["Valuation_Billions"] < 15)]
87
  elif selected_valuation_range == "15-20":
88
- return df[(df["Valuation_Billions"] >= 15) & (df["Valuation_Billions"] < 20)]
89
  elif selected_valuation_range == "20+":
90
- return df[df["Valuation_Billions"] >= 20]
91
  else:
92
  return df # Fallback, should never happen
93
 
94
-
95
  # Filter investors by country, industry, investor selection, company selection, and valuation range
96
  def filter_investors(
97
  selected_country,
@@ -101,259 +103,102 @@ def filter_investors(
101
  exclude_countries,
102
  exclude_industries,
103
  exclude_companies,
104
- exclude_investors,
105
  selected_valuation_range
106
  ):
107
  filtered_data = data.copy()
108
 
109
- # 1) Valuation range filter
110
- filtered_data = filter_by_valuation_range(filtered_data, selected_valuation_range)
111
-
112
- # 2) Now apply the existing filters:
113
-
114
- # Inclusion filters
115
- if selected_country != "All":
116
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
117
- if selected_industry != "All":
118
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
119
- if selected_company != "All":
120
- filtered_data = filtered_data[filtered_data["Company"] == selected_company]
121
  if selected_investors:
122
- pattern = '|'.join([re.escape(inv) for inv in selected_investors])
123
- filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
124
-
125
- # Exclusion filters
126
  if exclude_countries:
127
  filtered_data = filtered_data[~filtered_data["Country"].isin(exclude_countries)]
128
  if exclude_industries:
129
  filtered_data = filtered_data[~filtered_data["Industry"].isin(exclude_industries)]
130
  if exclude_companies:
131
  filtered_data = filtered_data[~filtered_data["Company"].isin(exclude_companies)]
132
- if exclude_investors:
133
- pattern = '|'.join([re.escape(inv) for inv in exclude_investors])
134
- filtered_data = filtered_data[~filtered_data["Select_Investors"].str.contains(pattern, na=False)]
 
135
 
136
- investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
137
- filtered_investors = list(investor_company_mapping_filtered.keys())
138
- return filtered_investors, filtered_data
139
 
 
 
 
 
140
 
141
- # Generate Plotly graph
142
- # NEW: We add selected_valuation_range so we can check if the user selected 15-20 or 20+
143
- def generate_graph(investors, filtered_data, selected_valuation_range):
144
- if not investors:
145
- logger.warning("No investors selected.")
146
- return go.Figure()
147
 
148
- # Create a color map for investors
149
- unique_investors = investors
150
- num_colors = len(unique_investors)
151
- color_palette = [
152
- "#377eb8", "#e41a1c", "#4daf4a", "#984ea3",
153
- "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999"
154
- ]
155
- while num_colors > len(color_palette):
156
- color_palette.extend(color_palette)
157
 
158
- investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
 
159
 
160
- G = nx.Graph()
161
- for investor in investors:
162
- companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist()
163
- for company in companies:
164
- G.add_node(company)
165
- G.add_node(investor)
166
- G.add_edge(investor, company)
167
 
168
- pos = nx.spring_layout(G, seed=42)
169
- edge_x, edge_y = [], []
 
 
 
 
 
 
 
170
 
171
- for edge in G.edges():
 
172
  x0, y0 = pos[edge[0]]
173
  x1, y1 = pos[edge[1]]
174
- edge_x.extend([x0, x1, None])
175
- edge_y.extend([y0, y1, None])
176
-
177
- edge_trace = go.Scatter(
178
- x=edge_x,
179
- y=edge_y,
180
- line=dict(width=0.5, color='#aaaaaa'),
181
- hoverinfo='none',
182
- mode='lines'
183
- )
184
 
185
- node_x, node_y, node_text, node_textposition = [], [], [], []
186
- node_color, node_size, node_hovertext = [], [], []
187
 
188
- for node in G.nodes():
189
- x, y = pos[node]
190
- node_x.append(x)
191
- node_y.append(y)
192
- if node in investors:
193
- node_text.append(node) # Add investor labels
194
- node_color.append(investor_color_map[node])
195
- node_size.append(30)
196
- node_hovertext.append(f"Investor: {node}")
197
- node_textposition.append('top center')
198
- else:
199
- valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
200
- industry = filtered_data.loc[filtered_data["Company"] == node, "Industry"].values
201
- size = valuation[0] * 5 if len(valuation) > 0 and not pd.isnull(valuation[0]) else 15
202
- node_size.append(max(size, 10))
203
- node_color.append("#a6d854")
204
-
205
- # Build the hover label text
206
- hovertext = f"Company: {node}"
207
- if len(industry) > 0 and not pd.isnull(industry[0]):
208
- hovertext += f"<br>Industry: {industry[0]}"
209
- if len(valuation) > 0 and not pd.isnull(valuation[0]):
210
- hovertext += f"<br>Valuation: ${valuation[0]:.2f}B"
211
- node_hovertext.append(hovertext)
212
-
213
- # NEW: If valuation range is 15–20 or 20+, show hovertext for all companies
214
- if selected_valuation_range in ["15-20", "20+"]:
215
- node_text.append(hovertext) # show full text
216
- node_textposition.append('bottom center')
217
- else:
218
- # Old logic: only show the company name in certain conditions
219
- if (
220
- (len(valuation) > 0 and valuation[0] is not None and valuation[0] > 10) # Check if > 10B
221
- or (len(filtered_data) < 15)
222
- or (node in filtered_data.nlargest(5, "Valuation_Billions")["Company"].tolist())
223
- ):
224
- node_text.append(node) # Show just the company name
225
- node_textposition.append('bottom center')
226
- else:
227
- node_text.append("") # Hide company label
228
- node_textposition.append('bottom center')
229
-
230
- node_trace = go.Scatter(
231
- x=node_x,
232
- y=node_y,
233
- text=node_text,
234
- textposition=node_textposition,
235
- mode='markers+text',
236
- hoverinfo='text',
237
- hovertext=node_hovertext,
238
- textfont=dict(size=13), # Adjust label font size
239
- marker=dict(
240
- showscale=False,
241
- size=node_size,
242
- color=node_color,
243
- line=dict(width=0.5, color='#333333')
244
- )
245
- )
246
-
247
- # Compute total market cap
248
- total_market_cap = filtered_data["Valuation_Billions"].sum()
249
-
250
- fig = go.Figure(data=[edge_trace, node_trace])
251
-
252
- fig.update_layout(
253
- title="",
254
- titlefont_size=28,
255
- margin=dict(l=20, r=20, t=60, b=20),
256
- hovermode='closest',
257
- width=1200,
258
- height=800,
259
- autosize=True,
260
- xaxis=dict(showgrid=False, zeroline=False, visible=False),
261
- yaxis=dict(showgrid=False, zeroline=False, visible=False),
262
- showlegend=False, # Hide the legend to maximize space
263
- annotations=[
264
- dict(
265
- x=0.5, y=1.1, xref='paper', yref='paper',
266
- text=f"Combined Market Cap: ${total_market_cap:.1f} Billions",
267
- showarrow=False, font=dict(size=14), xanchor='center'
268
- )
269
- ]
270
- )
271
 
272
  return fig
273
 
274
-
275
- # Gradio app function
276
- def app(
277
- selected_country,
278
- selected_industry,
279
- selected_company,
280
- selected_investors,
281
- exclude_countries,
282
- exclude_industries,
283
- exclude_companies,
284
- exclude_investors,
285
- selected_valuation_range
286
- ):
287
- investors, filtered_data = filter_investors(
288
- selected_country,
289
- selected_industry,
290
- selected_investors,
291
- selected_company,
292
- exclude_countries,
293
- exclude_industries,
294
- exclude_companies,
295
- exclude_investors,
296
- selected_valuation_range
297
  )
298
- if not investors:
299
- return go.Figure()
300
- # NEW: Pass valuation_range to generate_graph
301
- graph = generate_graph(investors, filtered_data, selected_valuation_range)
302
- return graph
303
-
304
-
305
- def main():
306
- country_list = ["All"] + sorted(data["Country"].dropna().unique())
307
- industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
308
- company_list = ["All"] + sorted(data["Company"].dropna().unique())
309
- investor_list = sorted(investor_company_mapping.keys())
310
-
311
- # Valuation range choices
312
- valuation_ranges = ["All", "1-5", "5-10", "10-15", "15-20", "20+"]
313
-
314
- with gr.Blocks(title="Venture Networks Visualization") as demo:
315
- gr.Markdown("# Venture Networks Visualization")
316
-
317
- with gr.Row():
318
- country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
319
- industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
320
- company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
321
- investor_filter = gr.Dropdown(choices=investor_list, label="Select Investors", value=[], multiselect=True)
322
-
323
- with gr.Row():
324
- valuation_range_filter = gr.Dropdown(
325
- choices=valuation_ranges,
326
- label="Valuation Range (Billions)",
327
- value="All"
328
- )
329
- exclude_country_filter = gr.Dropdown(choices=country_list[1:], label="Exclude Country", value=[], multiselect=True)
330
- exclude_industry_filter = gr.Dropdown(choices=industry_list[1:], label="Exclude Industry", value=[], multiselect=True)
331
- exclude_company_filter = gr.Dropdown(choices=company_list[1:], label="Exclude Company", value=[], multiselect=True)
332
- exclude_investor_filter = gr.Dropdown(choices=investor_list, label="Exclude Investors", value=[], multiselect=True)
333
-
334
- graph_output = gr.Plot(label="Network Graph")
335
-
336
- inputs = [
337
- country_filter,
338
- industry_filter,
339
- company_filter,
340
- investor_filter,
341
- exclude_country_filter,
342
- exclude_industry_filter,
343
- exclude_company_filter,
344
- exclude_investor_filter,
345
- valuation_range_filter
346
- ]
347
- outputs = [graph_output]
348
-
349
- # Set up event triggers for all inputs
350
- for input_control in inputs:
351
- input_control.change(app, inputs, outputs)
352
-
353
- gr.Markdown("**Instructions:** Use the dropdowns to filter the network graph. For valuation ranges 15–20 or 20+, you’ll see each company's info label without hovering.")
354
-
355
- demo.launch()
356
-
357
-
358
- if __name__ == "__main__":
359
- main()
 
26
  data.columns = data.columns.str.strip().str.lower()
27
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
 
29
+ # Filter out 'Health' since 'Healthcare' is the correct Market Segment
30
  data = data[data.industry != 'Health']
31
 
32
  # Identify the valuation column
 
39
  logger.info(f"Using valuation column: {valuation_column}")
40
 
41
  # Clean and prepare data
42
+ data["valuation_billions"] = data[valuation_column].apply(
43
+ lambda x: float(re.sub(r"[^\d.]", "", str(x))) if pd.notnull(x) else 0
44
+ )
45
+ data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
46
+
47
+ # Clean string columns
48
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
49
  data.rename(columns={
50
  "company": "Company",
 
73
  investor_company_mapping = build_investor_company_mapping(data)
74
  logger.info("Investor to company mapping created.")
75
 
 
76
  # -------------------------
77
+ # Valuation-Range Logic
78
  # -------------------------
79
  def filter_by_valuation_range(df, selected_valuation_range):
80
  """Filter dataframe by the specified valuation range in billions."""
 
82
  return df # No further filtering
83
 
84
  if selected_valuation_range == "1-5":
85
+ return df[(df["valuation_billions"] >= 1) & (df["valuation_billions"] < 5)]
86
  elif selected_valuation_range == "5-10":
87
+ return df[(df["valuation_billions"] >= 5) & (df["valuation_billions"] < 10)]
88
  elif selected_valuation_range == "10-15":
89
+ return df[(df["valuation_billions"] >= 10) & (df["valuation_billions"] < 15)]
90
  elif selected_valuation_range == "15-20":
91
+ return df[(df["valuation_billions"] >= 15) & (df["valuation_billions"] < 20)]
92
  elif selected_valuation_range == "20+":
93
+ return df[df["valuation_billions"] >= 20]
94
  else:
95
  return df # Fallback, should never happen
96
 
 
97
  # Filter investors by country, industry, investor selection, company selection, and valuation range
98
  def filter_investors(
99
  selected_country,
 
103
  exclude_countries,
104
  exclude_industries,
105
  exclude_companies,
 
106
  selected_valuation_range
107
  ):
108
  filtered_data = data.copy()
109
 
110
+ # Apply filters
111
+ if selected_country:
 
 
 
 
 
112
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
113
+ if selected_industry:
114
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
 
 
115
  if selected_investors:
116
+ filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(selected_investors, na=False)]
117
+ if selected_company:
118
+ filtered_data = filtered_data[filtered_data["Company"] == selected_company]
 
119
  if exclude_countries:
120
  filtered_data = filtered_data[~filtered_data["Country"].isin(exclude_countries)]
121
  if exclude_industries:
122
  filtered_data = filtered_data[~filtered_data["Industry"].isin(exclude_industries)]
123
  if exclude_companies:
124
  filtered_data = filtered_data[~filtered_data["Company"].isin(exclude_companies)]
125
+ if selected_valuation_range:
126
+ filtered_data = filter_by_valuation_range(filtered_data, selected_valuation_range)
127
+
128
+ return filtered_data
129
 
130
+ # Create the graph visualization
131
+ def create_network_graph(filtered_data):
132
+ graph = nx.Graph()
133
 
134
+ # Add nodes and edges
135
+ for _, row in filtered_data.iterrows():
136
+ company = row['Company']
137
+ venture_firm = row['Select_Investors']
138
 
139
+ # Add company node (green color, size based on valuation)
140
+ graph.add_node(company, node_color='green', node_size=row['valuation_billions'] * 10)
 
 
 
 
141
 
142
+ # Add venture firm node (different color, fixed size)
143
+ graph.add_node(venture_firm, node_color='blue', node_size=30)
 
 
 
 
 
 
 
144
 
145
+ # Add an edge between the company and the venture firm
146
+ graph.add_edge(company, venture_firm)
147
 
148
+ # Generate visualization
149
+ pos = nx.spring_layout(graph) # Layout for positioning
150
+ fig = go.Figure()
 
 
 
 
151
 
152
+ # Add nodes
153
+ for node, attrs in graph.nodes(data=True):
154
+ fig.add_trace(go.Scatter(
155
+ x=[pos[node][0]], y=[pos[node][1]],
156
+ mode='markers+text',
157
+ marker=dict(size=attrs['node_size'], color=attrs['node_color']),
158
+ text=node,
159
+ textposition='top center'
160
+ ))
161
 
162
+ # Add edges
163
+ for edge in graph.edges:
164
  x0, y0 = pos[edge[0]]
165
  x1, y1 = pos[edge[1]]
166
+ fig.add_trace(go.Scatter(
167
+ x=[x0, x1, None], y=[y0, y1, None],
168
+ mode='lines',
169
+ line=dict(width=1, color='grey')
170
+ ))
 
 
 
 
 
171
 
172
+ fig.update_layout(title="Company-Investor Network", showlegend=False)
 
173
 
174
+ # Add the note to the plot
175
+ note = ("Note: All companies are in green while venture firms have different colors. "
176
+ "The diameter of the company circle varies proportionate to the valuation of the corresponding company.")
177
+ logger.info(note)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  return fig
180
 
181
+ # Gradio interface
182
+ def display_network(selected_country, selected_industry, selected_investors, selected_company,
183
+ exclude_countries, exclude_industries, exclude_companies, selected_valuation_range):
184
+ filtered_data = filter_investors(
185
+ selected_country, selected_industry, selected_investors, selected_company,
186
+ exclude_countries, exclude_industries, exclude_companies, selected_valuation_range
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
+ return create_network_graph(filtered_data)
189
+
190
+ # Gradio interface inputs
191
+ inputs = [
192
+ gr.Textbox(label="Select Country"),
193
+ gr.Textbox(label="Select Industry"),
194
+ gr.Textbox(label="Select Investors"),
195
+ gr.Textbox(label="Select Company"),
196
+ gr.Textbox(label="Exclude Countries"),
197
+ gr.Textbox(label="Exclude Industries"),
198
+ gr.Textbox(label="Exclude Companies"),
199
+ gr.Radio(choices=["All", "1-5", "5-10", "10-15", "15-20", "20+"], label="Valuation Range")
200
+ ]
201
+
202
+ # Launch the Gradio interface
203
+ interface = gr.Interface(fn=display_network, inputs=inputs, outputs="plot", live=True)
204
+ interface.launch()