LeonceNsh committed on
Commit
a46a8e7
·
verified ·
1 Parent(s): 5240604

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +316 -65
app.py CHANGED
@@ -26,7 +26,7 @@ except Exception as e:
26
  data.columns = data.columns.str.strip().str.lower()
27
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
 
29
- # Filter out 'Health' since 'Healthcare' is the correct Market Segment
30
  data = data[data.industry != 'Health']
31
 
32
  # Identify the valuation column
@@ -39,70 +39,321 @@ valuation_column = valuation_columns[0]
39
  logger.info(f"Using valuation column: {valuation_column}")
40
 
41
  # Clean and prepare data
42
- data["valuation_billions"] = data[valuation_column].apply(
43
- lambda x: float(re.sub(r"[^0-9.]", "", str(x))) / 1e9 if pd.notnull(x) else 0
44
- )
45
- logger.info("Valuation column cleaned and converted to billions.")
46
-
47
- # Create a graph
48
- G = nx.Graph()
49
-
50
- for _, row in data.iterrows():
51
- company_name = row["company"]
52
- valuation = row["valuation_billions"]
53
- industry = row["industry"]
54
-
55
- # Add company node
56
- G.add_node(
57
- company_name,
58
- size=valuation,
59
- color="green" if industry == "Venture Firm" else "blue",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  )
61
 
62
- # Add connections based on relationships (assume relationships column exists)
63
- if "relationships" in data.columns:
64
- relationships = str(row["relationships"]).split(";")
65
- for relation in relationships:
66
- G.add_edge(company_name, relation.strip())
67
-
68
- # Create Plotly visualization
69
- node_sizes = [G.nodes[node]["size"] * 50 for node in G.nodes]
70
- node_colors = [G.nodes[node]["color"] for node in G.nodes]
71
-
72
- pos = nx.spring_layout(G)
73
- x_coords = [pos[node][0] for node in G.nodes]
74
- y_coords = [pos[node][1] for node in G.nodes]
75
-
76
- fig = go.Figure()
77
-
78
- fig.add_trace(
79
- go.Scatter(
80
- x=x_coords,
81
- y=y_coords,
82
- mode="markers+text",
83
- marker=dict(size=node_sizes, color=node_colors, opacity=0.8),
84
- text=list(G.nodes),
85
- textposition="top center",
 
86
  )
87
- )
88
-
89
- fig.update_layout(
90
- title="Company Network Visualization",
91
- xaxis=dict(showgrid=False, zeroline=False),
92
- yaxis=dict(showgrid=False, zeroline=False),
93
- showlegend=False,
94
- )
95
-
96
- # Note: All companies are in green while venture firms have different colors.
97
- # The diameter of the company circle varies proportionate to the valuation of the corresponding company.
98
-
99
- # Create Gradio interface
100
- def display_network():
101
- return fig.to_html()
102
-
103
- gr.Interface(
104
- fn=display_network,
105
- inputs=[],
106
- outputs="html",
107
- title="Company Network Visualization",
108
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  data.columns = data.columns.str.strip().str.lower()
27
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
 
29
+ # Filter out Health since Healthcare is the correct Market Segment
30
  data = data[data.industry != 'Health']
31
 
32
  # Identify the valuation column
 
39
  logger.info(f"Using valuation column: {valuation_column}")
40
 
41
  # Clean and prepare data
42
# `data` was produced by boolean filtering above; take an explicit copy so the
# column assignments below operate on an independent frame instead of raising
# pandas' SettingWithCopyWarning.
data = data.copy()

# Strip currency symbols and thousands separators, then parse to numbers.
# The pattern is a raw string: '\$' in a plain literal is an invalid escape
# sequence that newer Python versions warn about.
# NOTE(review): the previous revision divided the parsed value by 1e9; no
# scaling happens here, so this assumes the source column is already expressed
# in billions -- confirm against the dataset.
data["Valuation_Billions"] = data[valuation_column].replace(
    {r"\$": "", ",": ""}, regex=True
)
# Unparseable values become NaN rather than raising.
data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors="coerce")

# Trim surrounding whitespace in every object-dtype (text) column.
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)

# Column names were lower-cased earlier; switch to the display-style names the
# rest of the module uses.
data = data.rename(columns={
    "company": "Company",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors",
})

logger.info("Data cleaned and columns renamed.")
55
+
56
+ # Build investor-company mapping
57
def build_investor_company_mapping(df):
    """Map each investor name to the companies that list it.

    Args:
        df: DataFrame with a "Company" column and a "Select_Investors" column
            holding comma-separated investor names (missing values allowed).

    Returns:
        dict mapping investor name -> list of company values, in row order.
        Blank names (e.g. from a trailing comma) are ignored.
    """
    mapping = {}
    # Iterate the two needed columns directly; iterrows() would build a full
    # Series object for every row.
    for company, investors in zip(df["Company"], df["Select_Investors"]):
        if pd.isnull(investors):
            continue
        # str() keeps a non-string cell (e.g. a numeric value) from crashing
        # on .split().
        for investor in str(investors).split(","):
            investor = investor.strip()
            if investor:
                mapping.setdefault(investor, []).append(company)
    return mapping
68
+
69
# Module-level investor -> companies lookup built once from the full dataset.
investor_company_mapping = build_investor_company_mapping(data)
logger.info("Investor to company mapping created.")
71
+
72
+
73
+ # -------------------------
74
+ # Valuation-Range Logic
75
+ # -------------------------
76
def _parse_valuation_range(label):
    """Parse "lo-hi" or "lo+" into (low, high); high is None when open-ended.

    Returns None when the label does not follow either form.
    """
    text = label.strip()
    try:
        if text.endswith("+"):
            return float(text[:-1]), None
        low_text, dash, high_text = text.partition("-")
        if not dash:
            return None
        return float(low_text), float(high_text)
    except ValueError:
        return None


def filter_by_valuation_range(df, selected_valuation_range):
    """Filter dataframe by the specified valuation range in billions.

    Args:
        df: DataFrame with a numeric "Valuation_Billions" column.
        selected_valuation_range: "All", a half-open range label "lo-hi"
            (lo <= value < hi) or an open-ended label "lo+" (value >= lo).
            The UI labels "1-5", "5-10", "10-15", "15-20" and "20+" behave
            as before; any other label of the same form is now honoured too.

    Returns:
        The matching rows. ``df`` itself is returned unfiltered for "All" and
        for labels that cannot be parsed. Rows whose valuation is NaN never
        match a range.
    """
    if selected_valuation_range == "All" or not isinstance(selected_valuation_range, str):
        return df  # No further filtering

    bounds = _parse_valuation_range(selected_valuation_range)
    if bounds is None:
        return df  # Unrecognised label: keep the original pass-through fallback

    low, high = bounds
    valuations = df["Valuation_Billions"]
    in_range = valuations >= low
    if high is not None:
        in_range = in_range & (valuations < high)
    return df[in_range]
93
+
94
+
95
+ # Filter investors by country, industry, investor selection, company selection, and valuation range
96
def _rows_listing_any(investor_cells, wanted):
    """Return a boolean Series marking cells that list any name in *wanted*.

    Each cell is treated as a comma-separated list and names are compared
    exactly after trimming whitespace.
    """
    wanted = set(wanted)
    flags = [
        bool(
            pd.notnull(cell)
            and any(part.strip() in wanted for part in str(cell).split(","))
        )
        for cell in investor_cells
    ]
    # Explicit bool dtype so an empty result is still a valid boolean indexer.
    return pd.Series(flags, index=investor_cells.index, dtype=bool)


def filter_investors(
    selected_country,
    selected_industry,
    selected_investors,
    selected_company,
    exclude_countries,
    exclude_industries,
    exclude_companies,
    exclude_investors,
    selected_valuation_range
):
    """Apply the UI filters to the module-level ``data`` frame.

    Args:
        selected_country: Country to keep, or "All".
        selected_industry: Industry to keep, or "All".
        selected_investors: Investor names; a company is kept when it lists
            at least one of them. Empty means no investor restriction.
        selected_company: Company to keep, or "All".
        exclude_countries: Countries whose companies are dropped.
        exclude_industries: Industries whose companies are dropped.
        exclude_companies: Companies to drop.
        exclude_investors: Investor names; a company listing any of them is
            dropped.
        selected_valuation_range: Label passed to ``filter_by_valuation_range``.

    Returns:
        ``(investors, filtered_data)``: every investor named by the surviving
        rows, and the surviving rows themselves.
    """
    filtered_data = data.copy()

    # 1) Valuation range filter
    filtered_data = filter_by_valuation_range(filtered_data, selected_valuation_range)

    # 2) Inclusion filters
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
    if selected_company != "All":
        filtered_data = filtered_data[filtered_data["Company"] == selected_company]
    if selected_investors:
        # Exact name matching: a regex substring search would also match
        # longer names that merely contain the selected one
        # (e.g. "Sequoia Capital" inside "Sequoia Capital China").
        keep = _rows_listing_any(filtered_data["Select_Investors"], selected_investors)
        filtered_data = filtered_data[keep]

    # 3) Exclusion filters
    if exclude_countries:
        filtered_data = filtered_data[~filtered_data["Country"].isin(exclude_countries)]
    if exclude_industries:
        filtered_data = filtered_data[~filtered_data["Industry"].isin(exclude_industries)]
    if exclude_companies:
        filtered_data = filtered_data[~filtered_data["Company"].isin(exclude_companies)]
    if exclude_investors:
        drop = _rows_listing_any(filtered_data["Select_Investors"], exclude_investors)
        filtered_data = filtered_data[~drop]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    filtered_investors = list(investor_company_mapping_filtered.keys())
    return filtered_investors, filtered_data
139
+
140
+
141
+ # Generate Plotly graph
142
+ # NEW: We add selected_valuation_range so we can check if the user selected 15-20 or 20+
143
def generate_graph(investors, filtered_data, selected_valuation_range):
    """Build the investor/company network figure.

    Args:
        investors: Investor names to draw as hub nodes.
        filtered_data: DataFrame providing the "Company", "Select_Investors",
            "Industry" and "Valuation_Billions" columns.
        selected_valuation_range: Active valuation-range label. For "15-20"
            and "20+" every company is labelled with its full details;
            otherwise only prominent companies get a name label.

    Returns:
        A Plotly figure with one edge trace and one node trace, or an empty
        figure when ``investors`` is empty.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # Cycle through a fixed palette so any number of investors gets a colour.
    base_palette = [
        "#377eb8", "#e41a1c", "#4daf4a", "#984ea3",
        "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999"
    ]
    investor_color_map = {
        investor: base_palette[i % len(base_palette)]
        for i, investor in enumerate(investors)
    }
    investor_names = set(investors)

    # Index companies by the exact investor names they list. A substring
    # search would also link an investor to companies backed by a
    # differently named investor that merely contains its name.
    companies_by_investor = {}
    for company, cell in zip(filtered_data["Company"], filtered_data["Select_Investors"]):
        if pd.isnull(cell):
            continue
        for name in {part.strip() for part in str(cell).split(",")}:
            companies_by_investor.setdefault(name, []).append(company)

    G = nx.Graph()
    for investor in investors:
        for company in companies_by_investor.get(investor, []):
            G.add_node(company)
            G.add_node(investor)
            G.add_edge(investor, company)

    # Fixed seed keeps the layout reproducible between refreshes.
    pos = nx.spring_layout(G, seed=42)

    edge_x, edge_y = [], []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # None separates the segments so one trace can draw every edge.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#aaaaaa'),
        hoverinfo='none',
        mode='lines'
    )

    # One pass over the frame instead of a DataFrame scan per node; the first
    # row wins when a company appears more than once.
    company_details = {}
    for company, valuation, industry in zip(
        filtered_data["Company"],
        filtered_data["Valuation_Billions"],
        filtered_data["Industry"],
    ):
        if pd.notnull(company):
            company_details.setdefault(company, (valuation, industry))

    label_all_companies = selected_valuation_range in ("15-20", "20+")
    few_companies = len(filtered_data) < 15
    # The top-5 set is only consulted when neither shortcut above applies.
    if label_all_companies or few_companies:
        top_companies = set()
    else:
        top_companies = set(filtered_data.nlargest(5, "Valuation_Billions")["Company"])

    node_x, node_y, node_text, node_textposition = [], [], [], []
    node_color, node_size, node_hovertext = [], [], []

    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

        if node in investor_names:
            node_text.append(node)  # Investors are always labelled
            node_color.append(investor_color_map[node])
            node_size.append(30)
            node_hovertext.append(f"Investor: {node}")
            node_textposition.append('top center')
            continue

        valuation, industry = company_details.get(node, (None, None))
        has_valuation = valuation is not None and not pd.isnull(valuation)

        # Marker size scales with valuation, with a floor of 10 and a default
        # of 15 when the valuation is unknown.
        size = valuation * 5 if has_valuation else 15
        node_size.append(max(size, 10))
        node_color.append("#a6d854")

        hovertext = f"Company: {node}"
        if industry is not None and not pd.isnull(industry):
            hovertext += f"<br>Industry: {industry}"
        if has_valuation:
            hovertext += f"<br>Valuation: ${valuation:.2f}B"
        node_hovertext.append(hovertext)

        if label_all_companies:
            node_text.append(hovertext)  # Show the full details as the label
        elif (has_valuation and valuation > 10) or few_companies or node in top_companies:
            node_text.append(node)  # Show just the company name
        else:
            node_text.append("")  # Hide company label
        node_textposition.append('bottom center')

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        textposition=node_textposition,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_hovertext,
        textfont=dict(size=13),
        marker=dict(
            showscale=False,
            size=node_size,
            color=node_color,
            line=dict(width=0.5, color='#333333')
        )
    )

    # Sum of valuations across all filtered rows (NaN values are skipped).
    total_market_cap = filtered_data["Valuation_Billions"].sum()

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        # Set the font through layout.title.font; `titlefont` is a deprecated
        # alias.
        title=dict(text="", font=dict(size=28)),
        margin=dict(l=20, r=20, t=60, b=20),
        hovermode='closest',
        width=1200,
        height=800,
        autosize=True,
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        yaxis=dict(showgrid=False, zeroline=False, visible=False),
        showlegend=False,  # Hide the legend to maximize space
        annotations=[
            dict(
                x=0.5, y=1.1, xref='paper', yref='paper',
                text=f"Combined Market Cap: ${total_market_cap:.1f} Billions",
                showarrow=False, font=dict(size=14), xanchor='center'
            )
        ]
    )

    return fig
273
+
274
+
275
+ # Gradio app function
276
def app(
    selected_country,
    selected_industry,
    selected_company,
    selected_investors,
    exclude_countries,
    exclude_industries,
    exclude_companies,
    exclude_investors,
    selected_valuation_range
):
    """Gradio callback: filter the dataset and return the network figure.

    The parameters arrive in the order of the UI controls; they are handed to
    ``filter_investors`` by keyword because that function expects the investor
    and company selections in the opposite order.

    Returns:
        The figure from ``generate_graph``, or an empty figure when no
        investor survives the filters.
    """
    investors, filtered_data = filter_investors(
        selected_country=selected_country,
        selected_industry=selected_industry,
        selected_investors=selected_investors,
        selected_company=selected_company,
        exclude_countries=exclude_countries,
        exclude_industries=exclude_industries,
        exclude_companies=exclude_companies,
        exclude_investors=exclude_investors,
        selected_valuation_range=selected_valuation_range,
    )
    if investors:
        # The valuation range also controls how companies are labelled.
        return generate_graph(investors, filtered_data, selected_valuation_range)
    return go.Figure()
303
+
304
+
305
def main():
    """Assemble the Gradio interface and launch it."""

    def with_all(column):
        # Sorted distinct non-null values of a column, preceded by "All".
        return ["All", *sorted(data[column].dropna().unique())]

    def multi_dropdown(options, caption):
        # Multi-select dropdown that starts with nothing selected.
        return gr.Dropdown(choices=options, label=caption, value=[], multiselect=True)

    country_list = with_all("Country")
    industry_list = with_all("Industry")
    company_list = with_all("Company")
    investor_list = sorted(investor_company_mapping.keys())

    # Valuation range choices
    valuation_ranges = ["All", "1-5", "5-10", "10-15", "15-20", "20+"]

    with gr.Blocks(title="Venture Networks Visualization") as demo:
        gr.Markdown("# Venture Networks Visualization")

        # First row: inclusion filters.
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
            company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
            investor_filter = multi_dropdown(investor_list, "Select Investors")

        # Second row: valuation range and exclusion filters ("All" is not offered).
        with gr.Row():
            valuation_range_filter = gr.Dropdown(
                choices=valuation_ranges,
                label="Valuation Range (Billions)",
                value="All"
            )
            exclude_country_filter = multi_dropdown(country_list[1:], "Exclude Country")
            exclude_industry_filter = multi_dropdown(industry_list[1:], "Exclude Industry")
            exclude_company_filter = multi_dropdown(company_list[1:], "Exclude Company")
            exclude_investor_filter = multi_dropdown(investor_list, "Exclude Investors")

        graph_output = gr.Plot(label="Network Graph")

        # Order must match the positional parameters of `app`.
        inputs = [
            country_filter,
            industry_filter,
            company_filter,
            investor_filter,
            exclude_country_filter,
            exclude_industry_filter,
            exclude_company_filter,
            exclude_investor_filter,
            valuation_range_filter,
        ]
        outputs = [graph_output]

        # Re-render the graph whenever any control changes.
        for control in inputs:
            control.change(app, inputs, outputs)

        gr.Markdown("**Instructions:** Use the dropdowns to filter the network graph. For valuation ranges 15–20 or 20+, you’ll see each company's info label without hovering.")

    demo.launch()
356
+
357
+
358
# Launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    main()