LeonceNsh commited on
Commit
b954334
·
verified ·
1 Parent(s): 05d82ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -75
app.py CHANGED
@@ -26,8 +26,7 @@ except Exception as e:
26
  data.columns = data.columns.str.strip().str.lower()
27
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
 
29
- # Filter out Health since Healthcare is the correct Market Sergment
30
- print(data.head())
31
  data = data[data.industry != 'Health']
32
 
33
  # Identify the valuation column
@@ -97,21 +96,14 @@ def generate_graph(investors, filtered_data):
97
  unique_investors = investors
98
  num_colors = len(unique_investors)
99
  color_palette = [
100
- "#377eb8", # Blue
101
- "#e41a1c", # Red
102
- "#4daf4a", # Green
103
- "#984ea3", # Purple
104
- "#ff7f00", # Orange
105
- "#ffff33", # Yellow
106
- "#a65628", # Brown
107
- "#f781bf", # Pink
108
- "#999999", # Grey
109
  ]
110
  while num_colors > len(color_palette):
111
  color_palette.extend(color_palette)
112
-
113
  investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
114
-
115
  G = nx.Graph()
116
  for investor in investors:
117
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist()
@@ -119,17 +111,16 @@ def generate_graph(investors, filtered_data):
119
  G.add_node(company)
120
  G.add_node(investor)
121
  G.add_edge(investor, company)
122
-
123
  pos = nx.spring_layout(G, seed=42)
124
- edge_x = []
125
- edge_y = []
126
-
127
  for edge in G.edges():
128
  x0, y0 = pos[edge[0]]
129
  x1, y1 = pos[edge[1]]
130
  edge_x.extend([x0, x1, None])
131
  edge_y.extend([y0, y1, None])
132
-
133
  edge_trace = go.Scatter(
134
  x=edge_x,
135
  y=edge_y,
@@ -137,47 +128,38 @@ def generate_graph(investors, filtered_data):
137
  hoverinfo='none',
138
  mode='lines'
139
  )
140
-
141
- node_x = []
142
- node_y = []
143
- node_text = []
144
- node_color = []
145
- node_size = []
146
- node_hovertext = []
147
-
148
  for node in G.nodes():
149
  x, y = pos[node]
150
  node_x.append(x)
151
  node_y.append(y)
152
  if node in investors:
153
- node_text.append(node)
154
  node_color.append(investor_color_map[node])
155
  node_size.append(30)
156
  node_hovertext.append(f"Investor: {node}")
157
  else:
158
  valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
159
  industry = filtered_data.loc[filtered_data["Company"] == node, "Industry"].values
160
- if len(valuation) > 0 and not pd.isnull(valuation[0]):
161
- size = valuation[0] * 5
162
- if size < 10:
163
- size = 10
164
- else:
165
- size = 15
166
- node_size.append(size)
167
- node_text.append("")
168
  node_color.append("#a6d854")
169
  hovertext = f"Company: {node}"
170
  if len(industry) > 0 and not pd.isnull(industry[0]):
171
  hovertext += f"<br>Industry: {industry[0]}"
172
  if len(valuation) > 0 and not pd.isnull(valuation[0]):
173
- hovertext += f"<br>Valuation: ${valuation[0]}B"
174
  node_hovertext.append(hovertext)
175
-
176
  node_trace = go.Scatter(
177
  x=node_x,
178
  y=node_y,
179
  text=node_text,
180
- mode='markers+text',
181
  hoverinfo='text',
182
  hovertext=node_hovertext,
183
  marker=dict(
@@ -185,28 +167,20 @@ def generate_graph(investors, filtered_data):
185
  size=node_size,
186
  color=node_color,
187
  line=dict(width=0.5, color='#333333')
188
- ),
189
- textposition="middle center",
190
- textfont=dict(size=12, color="#000000")
191
- )
192
-
193
- legend_items = []
194
- for investor in unique_investors:
195
- legend_items.append(
196
- go.Scatter(
197
- x=[None],
198
- y=[None],
199
- mode='markers',
200
- marker=dict(
201
- size=10,
202
- color=investor_color_map[investor]
203
- ),
204
- legendgroup=investor,
205
- showlegend=True,
206
- name=investor
207
- )
208
  )
209
-
 
 
 
 
 
 
 
 
 
 
 
 
210
  fig = go.Figure(data=legend_items + [edge_trace, node_trace])
211
  fig.update_layout(
212
  title="Venture Networks",
@@ -214,15 +188,19 @@ def generate_graph(investors, filtered_data):
214
  margin=dict(l=20, r=20, t=60, b=20),
215
  hovermode='closest',
216
  width=1200,
217
- height=800
218
- )
219
-
220
- fig.update_layout(
221
  autosize=True,
222
  xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
223
- yaxis={'showgrid': False, 'zeroline': False, 'visible': False}
 
 
 
 
 
 
 
224
  )
225
-
226
  return fig
227
 
228
  # Gradio app
@@ -233,18 +211,14 @@ def app(selected_country, selected_industry, selected_company, selected_investor
233
  graph = generate_graph(investors, filtered_data)
234
  return ', '.join(investors), graph
235
 
236
- # Main function
237
  def main():
238
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
239
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
240
  company_list = ["All"] + sorted(data["Company"].dropna().unique())
241
  investor_list = sorted(investor_company_mapping.keys())
242
-
243
  with gr.Blocks(title="Venture Networks Visualization") as demo:
244
- gr.Markdown("""
245
- # Venture Networks Visualization
246
- Explore the connections between investors and companies in the venture capital ecosystem. Use the filters below to customize the network graph.
247
- """)
248
  with gr.Row():
249
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
250
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
@@ -261,10 +235,8 @@ def main():
261
  industry_filter.change(app, inputs, outputs)
262
  company_filter.change(app, inputs, outputs)
263
  investor_filter.change(app, inputs, outputs)
264
-
265
- gr.Markdown("""
266
- **Instructions:** Use the dropdowns to filter the network graph.
267
- """)
268
 
269
  demo.launch()
270
 
 
26
  data.columns = data.columns.str.strip().str.lower()
27
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
 
29
+ # Filter out Health since Healthcare is the correct Market Segment
 
30
  data = data[data.industry != 'Health']
31
 
32
  # Identify the valuation column
 
96
  unique_investors = investors
97
  num_colors = len(unique_investors)
98
  color_palette = [
99
+ "#377eb8", "#e41a1c", "#4daf4a", "#984ea3",
100
+ "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999"
 
 
 
 
 
 
 
101
  ]
102
  while num_colors > len(color_palette):
103
  color_palette.extend(color_palette)
104
+
105
  investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
106
+
107
  G = nx.Graph()
108
  for investor in investors:
109
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist()
 
111
  G.add_node(company)
112
  G.add_node(investor)
113
  G.add_edge(investor, company)
114
+
115
  pos = nx.spring_layout(G, seed=42)
116
+ edge_x, edge_y = [], []
117
+
 
118
  for edge in G.edges():
119
  x0, y0 = pos[edge[0]]
120
  x1, y1 = pos[edge[1]]
121
  edge_x.extend([x0, x1, None])
122
  edge_y.extend([y0, y1, None])
123
+
124
  edge_trace = go.Scatter(
125
  x=edge_x,
126
  y=edge_y,
 
128
  hoverinfo='none',
129
  mode='lines'
130
  )
131
+
132
+ node_x, node_y, node_text = [], [], []
133
+ node_color, node_size, node_hovertext = [], [], []
134
+
 
 
 
 
135
  for node in G.nodes():
136
  x, y = pos[node]
137
  node_x.append(x)
138
  node_y.append(y)
139
  if node in investors:
140
+ node_text.append("") # Remove investor labels
141
  node_color.append(investor_color_map[node])
142
  node_size.append(30)
143
  node_hovertext.append(f"Investor: {node}")
144
  else:
145
  valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
146
  industry = filtered_data.loc[filtered_data["Company"] == node, "Industry"].values
147
+ size = valuation[0] * 5 if len(valuation) > 0 and not pd.isnull(valuation[0]) else 15
148
+ node_size.append(max(size, 10))
149
+ node_text.append("") # No text label for companies
 
 
 
 
 
150
  node_color.append("#a6d854")
151
  hovertext = f"Company: {node}"
152
  if len(industry) > 0 and not pd.isnull(industry[0]):
153
  hovertext += f"<br>Industry: {industry[0]}"
154
  if len(valuation) > 0 and not pd.isnull(valuation[0]):
155
+ hovertext += f"<br>Valuation: ${valuation[0]:.2f}B"
156
  node_hovertext.append(hovertext)
157
+
158
  node_trace = go.Scatter(
159
  x=node_x,
160
  y=node_y,
161
  text=node_text,
162
+ mode='markers',
163
  hoverinfo='text',
164
  hovertext=node_hovertext,
165
  marker=dict(
 
167
  size=node_size,
168
  color=node_color,
169
  line=dict(width=0.5, color='#333333')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  )
171
+ )
172
+
173
+ legend_items = [
174
+ go.Scatter(
175
+ x=[None], y=[None], mode='markers',
176
+ marker=dict(size=10, color=investor_color_map[investor]),
177
+ legendgroup=investor, showlegend=True, name=investor
178
+ ) for investor in unique_investors
179
+ ]
180
+
181
+ # Compute total market cap
182
+ total_market_cap = filtered_data["Valuation_Billions"].sum()
183
+
184
  fig = go.Figure(data=legend_items + [edge_trace, node_trace])
185
  fig.update_layout(
186
  title="Venture Networks",
 
188
  margin=dict(l=20, r=20, t=60, b=20),
189
  hovermode='closest',
190
  width=1200,
191
+ height=800,
 
 
 
192
  autosize=True,
193
  xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
194
+ yaxis={'showgrid': False, 'zeroline': False, 'visible': False},
195
+ annotations=[
196
+ dict(
197
+ x=0.5, y=1.1, xref='paper', yref='paper',
198
+ text=f"Total Market Cap of Companies: ${total_market_cap:.2f}B",
199
+ showarrow=False, font=dict(size=14), xanchor='center'
200
+ )
201
+ ]
202
  )
203
+
204
  return fig
205
 
206
  # Gradio app
 
211
  graph = generate_graph(investors, filtered_data)
212
  return ', '.join(investors), graph
213
 
 
214
  def main():
215
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
216
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
217
  company_list = ["All"] + sorted(data["Company"].dropna().unique())
218
  investor_list = sorted(investor_company_mapping.keys())
219
+
220
  with gr.Blocks(title="Venture Networks Visualization") as demo:
221
+ gr.Markdown("# Venture Networks Visualization")
 
 
 
222
  with gr.Row():
223
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
224
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
 
235
  industry_filter.change(app, inputs, outputs)
236
  company_filter.change(app, inputs, outputs)
237
  investor_filter.change(app, inputs, outputs)
238
+
239
+ gr.Markdown("**Instructions:** Use the dropdowns to filter the network graph.")
 
 
240
 
241
  demo.launch()
242