LeonceNsh committed on
Commit
3e1452c
·
verified ·
1 Parent(s): e6f1a9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -187
app.py CHANGED
@@ -4,18 +4,18 @@ import plotly.graph_objects as go
4
  import gradio as gr
5
  import logging
6
 
7
- # -------------------- Setup Logging --------------------
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
- # -------------------- Load and Preprocess Data --------------------
12
- FILE_PATH = "cbinsights_data.csv" # Replace with your actual file path
13
 
14
  try:
15
- data = pd.read_csv(FILE_PATH, skiprows=1)
16
  logger.info("CSV file loaded successfully.")
17
  except FileNotFoundError:
18
- logger.error(f"File not found: {FILE_PATH}")
19
  raise
20
  except Exception as e:
21
  logger.error(f"Error loading CSV file: {e}")
@@ -35,8 +35,8 @@ valuation_column = valuation_columns[0]
35
  logger.info(f"Using valuation column: {valuation_column}")
36
 
37
  # Clean and prepare data
38
- data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
39
- data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
40
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
41
  data.rename(columns={
42
  "company": "Company",
@@ -49,11 +49,8 @@ data.rename(columns={
49
 
50
  logger.info("Data cleaned and columns renamed.")
51
 
52
- # -------------------- Build Investor-Company Mapping --------------------
53
  def build_investor_company_mapping(df):
54
- """
55
- Builds a mapping from investors to the companies they've invested in.
56
- """
57
  mapping = {}
58
  for _, row in df.iterrows():
59
  company = row["Company"]
@@ -68,48 +65,30 @@ def build_investor_company_mapping(df):
68
  investor_company_mapping = build_investor_company_mapping(data)
69
  logger.info("Investor to company mapping created.")
70
 
71
- # -------------------- Filter Investors --------------------
72
  def filter_investors(selected_country, selected_industry, selected_investors):
73
- """
74
- Filters the dataset based on selected country, industry, and investors.
75
- """
76
  filtered_data = data.copy()
77
  if selected_country != "All":
78
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
79
  if selected_industry != "All":
80
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
81
  if selected_investors:
82
- pattern = '|'.join(selected_investors)
83
  filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
84
-
85
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
86
  filtered_investors = list(investor_company_mapping_filtered.keys())
87
  return filtered_investors, filtered_data
88
 
89
- # -------------------- Generate Network Graph --------------------
90
  def generate_graph(investors, filtered_data):
91
- """
92
- Generates an interactive network graph using Plotly.
93
- """
94
  if not investors:
95
  logger.warning("No investors selected.")
96
- # Create an empty figure with a message
97
- fig = go.Figure()
98
- fig.update_layout(
99
- title="No data available for the selected filters.",
100
- xaxis=dict(visible=False),
101
- yaxis=dict(visible=False),
102
- annotations=[dict(
103
- text="Please adjust your filters to display the network graph.",
104
- showarrow=False,
105
- xref="paper", yref="paper",
106
- x=0.5, y=0.5,
107
- font=dict(size=20)
108
- )]
109
- )
110
- return fig
111
 
112
- # Define a color-blind friendly palette
 
 
113
  color_palette = [
114
  "#377eb8", # Blue
115
  "#e41a1c", # Red
@@ -119,216 +98,200 @@ def generate_graph(investors, filtered_data):
119
  "#ffff33", # Yellow
120
  "#a65628", # Brown
121
  "#f781bf", # Pink
 
122
  ]
123
-
124
- # Assign colors to investors
125
- unique_investors = investors
126
- num_colors = len(unique_investors)
127
- color_palette_extended = color_palette * (num_colors // len(color_palette) + 1)
128
- investor_color_map = {investor: color_palette_extended[i] for i, investor in enumerate(unique_investors)}
129
-
130
- # Create the graph
131
  G = nx.Graph()
132
  for investor in investors:
133
- companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
134
  for company in companies:
135
- G.add_node(company, type='company', valuation=filtered_data.loc[filtered_data["Company"] == company, "Valuation_Billions"].values[0])
136
- G.add_node(investor, type='investor')
137
  G.add_edge(investor, company)
138
-
139
- # Position nodes using spring layout
140
- pos = nx.spring_layout(G, seed=42, k=0.5)
141
-
142
- # Prepare edge traces
143
  edge_x = []
144
  edge_y = []
 
145
  for edge in G.edges():
146
  x0, y0 = pos[edge[0]]
147
  x1, y1 = pos[edge[1]]
148
- edge_x += [x0, x1, None]
149
- edge_y += [y0, y1, None]
150
-
151
  edge_trace = go.Scatter(
152
  x=edge_x,
153
  y=edge_y,
154
- line=dict(width=0.5, color='#888'),
155
  hoverinfo='none',
156
  mode='lines'
157
  )
158
-
159
- # Prepare node traces
160
  node_x = []
161
  node_y = []
162
  node_text = []
163
  node_color = []
164
  node_size = []
165
- node_type = []
166
-
167
- for node in G.nodes(data=True):
168
- x, y = pos[node[0]]
169
  node_x.append(x)
170
  node_y.append(y)
171
- node_type.append(node[1]['type'])
172
- if node[1]['type'] == 'investor':
173
- node_text.append(node[0]) # Investor labels
174
- node_color.append(investor_color_map[node[0]])
175
- node_size.append(20) # Fixed size for investors
 
176
  else:
177
- valuation = node[1]['valuation']
178
- size = (valuation * 5) if pd.notnull(valuation) else 10 # Scale size
179
- size = max(size, 5) # Minimum size
 
 
 
 
 
 
180
  node_size.append(size)
181
- node_text.append(f"{node[0]}<br>Valuation: ${valuation}B" if pd.notnull(valuation) else f"{node[0]}<br>Valuation: N/A")
182
- node_color.append("#b2df8a") # Light green for companies
183
-
 
 
 
 
 
 
184
  node_trace = go.Scatter(
185
  x=node_x,
186
  y=node_y,
187
- mode='markers+text',
188
  text=node_text,
189
- textposition="top center",
190
  hoverinfo='text',
 
191
  marker=dict(
192
  showscale=False,
193
- color=node_color,
194
  size=node_size,
195
- line=dict(width=1, color='white')
196
- )
 
 
 
197
  )
198
-
199
- # Create the figure
200
- fig = go.Figure(data=[edge_trace, node_trace])
201
-
202
  # Add legend manually
203
- investor_colors = list(investor_color_map.values())[:8] # Limit to first 8 for legend
204
- investor_names = list(investor_color_map.keys())[:8]
205
-
206
- for i, investor in enumerate(investor_names):
207
- fig.add_trace(go.Scatter(
208
- x=[None],
209
- y=[None],
210
- mode='markers',
211
- marker=dict(
212
- size=10,
213
- color=investor_color_map[investor]
214
- ),
215
- legendgroup='Investors',
216
- showlegend=True,
217
- name=investor
218
- ))
219
-
220
- # Update layout for better aesthetics
221
  fig.update_layout(
222
- title={
223
- 'text': "Venture Networks",
224
- 'y':0.95,
225
- 'x':0.5,
226
- 'xanchor': 'center',
227
- 'yanchor': 'top'
228
- },
229
  titlefont_size=24,
230
- showlegend=True,
231
- legend=dict(
232
- title="Top Investors",
233
- itemsizing='constant',
234
- itemclick='toggleothers',
235
- itemdoubleclick='toggle',
236
- font=dict(size=10)
237
- ),
238
- margin=dict(l=40, r=40, t=80, b=40),
239
  hovermode='closest',
240
- width=1200,
241
  height=800,
242
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
243
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
 
 
 
244
  )
245
-
 
 
 
 
 
 
 
246
  return fig
247
 
248
- # -------------------- Gradio Application --------------------
249
  def app(selected_country, selected_industry, selected_investors):
250
- """
251
- Main application function that filters data and generates the network graph.
252
- """
253
  investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
 
 
254
  graph = generate_graph(investors, filtered_data)
255
- return ', '.join(investors) if investors else "No investors found.", graph
256
 
 
257
  def main():
258
- """
259
- Initializes and launches the Gradio interface.
260
- """
261
- # Prepare dropdown options
262
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
263
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
264
  investor_list = sorted(investor_company_mapping.keys())
265
-
266
- # Define Gradio Blocks
267
- with gr.Blocks(css="""
268
- .gradio-container {
269
- background-color: #f9f9f9;
270
- padding: 20px;
271
- }
272
- .gradio-row {
273
- justify-content: center;
274
- }
275
- """) as demo:
276
  gr.Markdown("""
277
  # Venture Networks Visualization
278
-
279
- Explore the relationships between investors and companies across different countries and industries. Use the filters below to customize the network graph.
280
-
281
- **Instructions:**
282
- - Select a country and/or industry to filter the data.
283
- - Choose one or more investors to focus on specific investment activities.
284
- - Hover over company nodes to view their valuations.
285
  """)
286
-
287
  with gr.Row():
288
- with gr.Column(scale=1):
289
- country_filter = gr.Dropdown(
290
- choices=country_list,
291
- label="Country",
292
- value="All",
293
- info="Select a country to filter the data."
294
- )
295
- industry_filter = gr.Dropdown(
296
- choices=industry_list,
297
- label="Industry",
298
- value="All",
299
- info="Select an industry to filter the data."
300
- )
301
- investor_filter = gr.Dropdown(
302
- choices=investor_list,
303
- label="Investor",
304
- value=[],
305
- multiselect=True,
306
- info="Select one or more investors to focus on their investments."
307
- )
308
- reset_button = gr.Button("Reset Filters", variant="secondary")
309
- with gr.Column(scale=3):
310
- graph_output = gr.Plot(label="Network Graph")
311
-
312
  with gr.Row():
313
  investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
314
-
315
- # Define Inputs and Outputs
316
  inputs = [country_filter, industry_filter, investor_filter]
317
  outputs = [investor_output, graph_output]
318
-
319
- # Define Event Handlers
320
- country_filter.change(fn=app, inputs=inputs, outputs=outputs)
321
- industry_filter.change(fn=app, inputs=inputs, outputs=outputs)
322
- investor_filter.change(fn=app, inputs=inputs, outputs=outputs)
323
- reset_button.click(fn=lambda: ["", go.Figure()], inputs=None, outputs=outputs)
324
-
325
- # Add Footer
326
  gr.Markdown("""
327
- ---
328
- © 2024 Venture Networks Visualization Tool
 
 
 
 
 
 
 
329
  """)
330
-
331
- # Launch the Gradio app
332
  demo.launch()
333
 
334
  if __name__ == "__main__":
 
4
  import gradio as gr
5
  import logging
6
 
7
+ # Set up logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
+ # Load and preprocess the dataset
12
+ file_path = "cbinsights_data.csv" # Replace with your actual file path
13
 
14
  try:
15
+ data = pd.read_csv(file_path, skiprows=1)
16
  logger.info("CSV file loaded successfully.")
17
  except FileNotFoundError:
18
+ logger.error(f"File not found: {file_path}")
19
  raise
20
  except Exception as e:
21
  logger.error(f"Error loading CSV file: {e}")
 
35
  logger.info(f"Using valuation column: {valuation_column}")
36
 
37
  # Clean and prepare data
38
+ data["Valuation_Billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
39
+ data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
40
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
41
  data.rename(columns={
42
  "company": "Company",
 
49
 
50
  logger.info("Data cleaned and columns renamed.")
51
 
52
+ # Build investor-company mapping
53
  def build_investor_company_mapping(df):
 
 
 
54
  mapping = {}
55
  for _, row in df.iterrows():
56
  company = row["Company"]
 
65
  investor_company_mapping = build_investor_company_mapping(data)
66
  logger.info("Investor to company mapping created.")
67
 
68
+ # Filter investors by country, industry, and investor selection
69
  def filter_investors(selected_country, selected_industry, selected_investors):
 
 
 
70
  filtered_data = data.copy()
71
  if selected_country != "All":
72
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
73
  if selected_industry != "All":
74
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
75
  if selected_investors:
76
+ pattern = '|'.join([re.escape(inv) for inv in selected_investors])
77
  filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
78
+
79
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
80
  filtered_investors = list(investor_company_mapping_filtered.keys())
81
  return filtered_investors, filtered_data
82
 
83
+ # Generate Plotly graph
84
  def generate_graph(investors, filtered_data):
 
 
 
85
  if not investors:
86
  logger.warning("No investors selected.")
87
+ return go.Figure()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ # Create a color map for investors
90
+ unique_investors = investors
91
+ num_colors = len(unique_investors)
92
  color_palette = [
93
  "#377eb8", # Blue
94
  "#e41a1c", # Red
 
98
  "#ffff33", # Yellow
99
  "#a65628", # Brown
100
  "#f781bf", # Pink
101
+ "#999999", # Grey
102
  ]
103
+ # Extend color_palette if necessary
104
+ while num_colors > len(color_palette):
105
+ color_palette.extend(color_palette)
106
+
107
+ investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
108
+
 
 
109
  G = nx.Graph()
110
  for investor in investors:
111
+ companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist()
112
  for company in companies:
113
+ G.add_node(company)
114
+ G.add_node(investor)
115
  G.add_edge(investor, company)
116
+
117
+ pos = nx.spring_layout(G, seed=42)
 
 
 
118
  edge_x = []
119
  edge_y = []
120
+
121
  for edge in G.edges():
122
  x0, y0 = pos[edge[0]]
123
  x1, y1 = pos[edge[1]]
124
+ edge_x.extend([x0, x1, None])
125
+ edge_y.extend([y0, y1, None])
126
+
127
  edge_trace = go.Scatter(
128
  x=edge_x,
129
  y=edge_y,
130
+ line=dict(width=0.5, color='#aaaaaa'),
131
  hoverinfo='none',
132
  mode='lines'
133
  )
134
+
 
135
  node_x = []
136
  node_y = []
137
  node_text = []
138
  node_color = []
139
  node_size = []
140
+ node_hovertext = []
141
+
142
+ for node in G.nodes():
143
+ x, y = pos[node]
144
  node_x.append(x)
145
  node_y.append(y)
146
+ if node in investors:
147
+ # Investor node
148
+ node_text.append(node) # Label investors
149
+ node_color.append(investor_color_map[node]) # Color assigned to investor
150
+ node_size.append(30) # Fixed size for investors
151
+ node_hovertext.append(f"Investor: {node}")
152
  else:
153
+ # Company node
154
+ valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
155
+ industry = filtered_data.loc[filtered_data["Company"] == node, "Industry"].values
156
+ if len(valuation) > 0 and not pd.isnull(valuation[0]):
157
+ size = valuation[0] * 5 # Scale size as needed
158
+ if size < 10:
159
+ size = 10 # Minimum size
160
+ else:
161
+ size = 15 # Default size
162
  node_size.append(size)
163
+ node_text.append("") # Hide company labels by default
164
+ node_color.append("#a6d854") # Light green color for companies
165
+ hovertext = f"Company: {node}"
166
+ if len(industry) > 0 and not pd.isnull(industry[0]):
167
+ hovertext += f"<br>Industry: {industry[0]}"
168
+ if len(valuation) > 0 and not pd.isnull(valuation[0]):
169
+ hovertext += f"<br>Valuation: ${valuation[0]}B"
170
+ node_hovertext.append(hovertext)
171
+
172
  node_trace = go.Scatter(
173
  x=node_x,
174
  y=node_y,
 
175
  text=node_text,
176
+ mode='markers+text',
177
  hoverinfo='text',
178
+ hovertext=node_hovertext,
179
  marker=dict(
180
  showscale=False,
 
181
  size=node_size,
182
+ color=node_color,
183
+ line=dict(width=0.5, color='#333333')
184
+ ),
185
+ textposition="middle center",
186
+ textfont=dict(size=12, color="#000000")
187
  )
188
+
 
 
 
189
  # Add legend manually
190
+ legend_items = []
191
+ for investor in unique_investors:
192
+ legend_items.append(
193
+ go.Scatter(
194
+ x=[None],
195
+ y=[None],
196
+ mode='markers',
197
+ marker=dict(
198
+ size=10,
199
+ color=investor_color_map[investor]
200
+ ),
201
+ legendgroup=investor,
202
+ showlegend=True,
203
+ name=investor
204
+ )
205
+ )
206
+
207
+ fig = go.Figure(data=legend_items + [edge_trace, node_trace])
208
  fig.update_layout(
209
+ title="Venture Networks",
 
 
 
 
 
 
210
  titlefont_size=24,
211
+ margin=dict(l=20, r=20, t=60, b=20),
 
 
 
 
 
 
 
 
212
  hovermode='closest',
213
+ width=1000,
214
  height=800,
215
+ legend=dict(
216
+ title="Investors",
217
+ font=dict(size=12),
218
+ itemsizing='constant'
219
+ )
220
  )
221
+
222
+ # Improve layout responsiveness
223
+ fig.update_layout(
224
+ autosize=True,
225
+ xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
226
+ yaxis={'showgrid': False, 'zeroline': False, 'visible': False}
227
+ )
228
+
229
  return fig
230
 
231
+ # Gradio app
232
  def app(selected_country, selected_industry, selected_investors):
 
 
 
233
  investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
234
+ if not investors:
235
+ return "No investors found with the selected filters.", go.Figure()
236
  graph = generate_graph(investors, filtered_data)
237
+ return ', '.join(investors), graph
238
 
239
+ # Main function
240
  def main():
241
+ import re # Added import for regex
 
 
 
242
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
243
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
244
  investor_list = sorted(investor_company_mapping.keys())
245
+
246
+ with gr.Blocks(title="Venture Networks Visualization") as demo:
 
 
 
 
 
 
 
 
 
247
  gr.Markdown("""
248
  # Venture Networks Visualization
249
+ Explore the connections between investors and companies in the venture capital ecosystem. Use the filters below to customize the network graph.
 
 
 
 
 
 
250
  """)
 
251
  with gr.Row():
252
+ country_filter = gr.Dropdown(
253
+ choices=country_list,
254
+ label="Country",
255
+ value="All",
256
+ info="Filter companies by country."
257
+ )
258
+ industry_filter = gr.Dropdown(
259
+ choices=industry_list,
260
+ label="Industry",
261
+ value="All",
262
+ info="Filter companies by industry."
263
+ )
264
+ investor_filter = gr.Dropdown(
265
+ choices=investor_list,
266
+ label="Select Investors",
267
+ value=[],
268
+ multiselect=True,
269
+ info="Select one or more investors to visualize."
270
+ )
 
 
 
 
 
271
  with gr.Row():
272
  investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
273
+ graph_output = gr.Plot(label="Network Graph")
274
+
275
  inputs = [country_filter, industry_filter, investor_filter]
276
  outputs = [investor_output, graph_output]
277
+
278
+ # Update the graph when any filter changes
279
+ country_filter.change(app, inputs, outputs)
280
+ industry_filter.change(app, inputs, outputs)
281
+ investor_filter.change(app, inputs, outputs)
282
+
 
 
283
  gr.Markdown("""
284
+ **Instructions:**
285
+ - **Country**: Select a country to filter companies based on their location.
286
+ - **Industry**: Choose an industry to focus on companies within that sector.
287
+ - **Select Investors**: Pick one or more investors to visualize their network.
288
+
289
+ **Graph Interaction:**
290
+ - **Zoom**: Use your mouse wheel or trackpad to zoom in and out.
291
+ - **Pan**: Click and drag to move around the graph.
292
+ - **Hover**: Hover over nodes to see more information.
293
  """)
294
+
 
295
  demo.launch()
296
 
297
  if __name__ == "__main__":