LeonceNsh commited on
Commit
e6f1a9e
·
verified ·
1 Parent(s): 47c5bfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -79
app.py CHANGED
@@ -4,18 +4,18 @@ import plotly.graph_objects as go
4
  import gradio as gr
5
  import logging
6
 
7
- # Set up logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
- # Load and preprocess the dataset
12
- file_path = "cbinsights_data.csv" # Replace with your actual file path
13
 
14
  try:
15
- data = pd.read_csv(file_path, skiprows=1)
16
  logger.info("CSV file loaded successfully.")
17
  except FileNotFoundError:
18
- logger.error(f"File not found: {file_path}")
19
  raise
20
  except Exception as e:
21
  logger.error(f"Error loading CSV file: {e}")
@@ -35,8 +35,8 @@ valuation_column = valuation_columns[0]
35
  logger.info(f"Using valuation column: {valuation_column}")
36
 
37
  # Clean and prepare data
38
- data["Valuation_Billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
39
- data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
40
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
41
  data.rename(columns={
42
  "company": "Company",
@@ -49,8 +49,11 @@ data.rename(columns={
49
 
50
  logger.info("Data cleaned and columns renamed.")
51
 
52
- # Build investor-company mapping
53
  def build_investor_company_mapping(df):
 
 
 
54
  mapping = {}
55
  for _, row in df.iterrows():
56
  company = row["Company"]
@@ -65,8 +68,11 @@ def build_investor_company_mapping(df):
65
  investor_company_mapping = build_investor_company_mapping(data)
66
  logger.info("Investor to company mapping created.")
67
 
68
- # Filter investors by country, industry, and investor selection
69
  def filter_investors(selected_country, selected_industry, selected_investors):
 
 
 
70
  filtered_data = data.copy()
71
  if selected_country != "All":
72
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
@@ -75,144 +81,255 @@ def filter_investors(selected_country, selected_industry, selected_investors):
75
  if selected_investors:
76
  pattern = '|'.join(selected_investors)
77
  filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
78
-
79
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
80
  filtered_investors = list(investor_company_mapping_filtered.keys())
81
  return filtered_investors, filtered_data
82
 
83
- # Generate Plotly graph
84
  def generate_graph(investors, filtered_data):
 
 
 
85
  if not investors:
86
  logger.warning("No investors selected.")
87
- return go.Figure()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Create a color map for investors
90
- unique_investors = investors
91
- num_colors = len(unique_investors)
92
  color_palette = [
93
- "#000000", # black
94
- "#E69F00", # orange
95
- "#56B4E9", # sky blue
96
- "#009E73", # bluish green
97
- "#F0E442", # yellow
98
- "#0072B2", # blue
99
- "#D55E00", # vermillion
100
- "#CC79A7", # reddish purple
101
  ]
102
- # Extend color_palette if necessary
103
- while num_colors > len(color_palette):
104
- color_palette.extend(color_palette)
105
-
106
- investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
107
-
 
 
108
  G = nx.Graph()
109
  for investor in investors:
110
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
111
  for company in companies:
112
- G.add_node(company)
113
- G.add_node(investor)
114
  G.add_edge(investor, company)
115
-
116
- pos = nx.spring_layout(G, seed=42)
 
 
 
117
  edge_x = []
118
  edge_y = []
119
-
120
  for edge in G.edges():
121
  x0, y0 = pos[edge[0]]
122
  x1, y1 = pos[edge[1]]
123
- edge_x.extend([x0, x1, None])
124
- edge_y.extend([y0, y1, None])
125
-
126
  edge_trace = go.Scatter(
127
  x=edge_x,
128
  y=edge_y,
129
- line=dict(width=1, color='#888'),
130
  hoverinfo='none',
131
  mode='lines'
132
  )
133
-
 
134
  node_x = []
135
  node_y = []
136
  node_text = []
137
  node_color = []
138
  node_size = []
139
-
140
- for node in G.nodes():
141
- x, y = pos[node]
 
142
  node_x.append(x)
143
  node_y.append(y)
144
- if node in investors:
145
- # Investor node
146
- node_text.append(node) # Label investors
147
- node_color.append(investor_color_map[node]) # Color assigned to investor
148
  node_size.append(20) # Fixed size for investors
149
  else:
150
- # Company node
151
- valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
152
- if len(valuation) > 0 and not pd.isnull(valuation[0]):
153
- size = valuation[0] * 5 # Scale size as needed
154
- if size < 5:
155
- size = 5 # Minimum size
156
- else:
157
- size = 10 # Default size
158
  node_size.append(size)
159
- node_text.append("") # Hide company labels by default
160
- node_color.append("#b2df8a") # Light green color for companies
161
-
162
  node_trace = go.Scatter(
163
  x=node_x,
164
  y=node_y,
 
165
  text=node_text,
166
- mode='markers',
167
  hoverinfo='text',
168
  marker=dict(
169
  showscale=False,
170
- size=node_size,
171
  color=node_color,
 
 
172
  )
173
  )
174
-
 
175
  fig = go.Figure(data=[edge_trace, node_trace])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  fig.update_layout(
177
- showlegend=False,
178
- title="Venture Networks",
179
- titlefont_size=20,
180
- margin=dict(l=20, r=20, t=50, b=20),
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  hovermode='closest',
182
  width=1200,
183
- height=800
 
 
184
  )
 
185
  return fig
186
 
187
- # Gradio app
188
  def app(selected_country, selected_industry, selected_investors):
 
 
 
189
  investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
190
  graph = generate_graph(investors, filtered_data)
191
- return ', '.join(investors), graph
192
 
193
- # Main function
194
  def main():
 
 
 
 
195
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
196
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
197
  investor_list = sorted(investor_company_mapping.keys())
198
 
199
- with gr.Blocks() as demo:
200
- with gr.Row():
201
- country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
202
- industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
203
- investor_filter = gr.Dropdown(choices=investor_list, label="Investor", value=[], multiselect=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- investor_output = gr.Textbox(label="Filtered Investors")
206
- graph_output = gr.Plot(label="Network Graph")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
 
208
  inputs = [country_filter, industry_filter, investor_filter]
209
  outputs = [investor_output, graph_output]
210
 
211
- country_filter.change(app, inputs, outputs)
212
- industry_filter.change(app, inputs, outputs)
213
- investor_filter.change(app, inputs, outputs)
 
 
 
 
 
 
 
 
214
 
215
- demo.launch()
 
216
 
217
  if __name__ == "__main__":
218
  main()
 
4
  import gradio as gr
5
  import logging
6
 
7
+ # -------------------- Setup Logging --------------------
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
+ # -------------------- Load and Preprocess Data --------------------
12
+ FILE_PATH = "cbinsights_data.csv" # Replace with your actual file path
13
 
14
  try:
15
+ data = pd.read_csv(FILE_PATH, skiprows=1)
16
  logger.info("CSV file loaded successfully.")
17
  except FileNotFoundError:
18
+ logger.error(f"File not found: {FILE_PATH}")
19
  raise
20
  except Exception as e:
21
  logger.error(f"Error loading CSV file: {e}")
 
35
  logger.info(f"Using valuation column: {valuation_column}")
36
 
37
  # Clean and prepare data
38
+ data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
39
+ data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
40
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
41
  data.rename(columns={
42
  "company": "Company",
 
49
 
50
  logger.info("Data cleaned and columns renamed.")
51
 
52
+ # -------------------- Build Investor-Company Mapping --------------------
53
  def build_investor_company_mapping(df):
54
+ """
55
+ Builds a mapping from investors to the companies they've invested in.
56
+ """
57
  mapping = {}
58
  for _, row in df.iterrows():
59
  company = row["Company"]
 
68
  investor_company_mapping = build_investor_company_mapping(data)
69
  logger.info("Investor to company mapping created.")
70
 
71
+ # -------------------- Filter Investors --------------------
72
  def filter_investors(selected_country, selected_industry, selected_investors):
73
+ """
74
+ Filters the dataset based on selected country, industry, and investors.
75
+ """
76
  filtered_data = data.copy()
77
  if selected_country != "All":
78
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
 
81
  if selected_investors:
82
  pattern = '|'.join(selected_investors)
83
  filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
84
+
85
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
86
  filtered_investors = list(investor_company_mapping_filtered.keys())
87
  return filtered_investors, filtered_data
88
 
89
+ # -------------------- Generate Network Graph --------------------
90
  def generate_graph(investors, filtered_data):
91
+ """
92
+ Generates an interactive network graph using Plotly.
93
+ """
94
  if not investors:
95
  logger.warning("No investors selected.")
96
+ # Create an empty figure with a message
97
+ fig = go.Figure()
98
+ fig.update_layout(
99
+ title="No data available for the selected filters.",
100
+ xaxis=dict(visible=False),
101
+ yaxis=dict(visible=False),
102
+ annotations=[dict(
103
+ text="Please adjust your filters to display the network graph.",
104
+ showarrow=False,
105
+ xref="paper", yref="paper",
106
+ x=0.5, y=0.5,
107
+ font=dict(size=20)
108
+ )]
109
+ )
110
+ return fig
111
 
112
+ # Define a color-blind friendly palette
 
 
113
  color_palette = [
114
+ "#377eb8", # Blue
115
+ "#e41a1c", # Red
116
+ "#4daf4a", # Green
117
+ "#984ea3", # Purple
118
+ "#ff7f00", # Orange
119
+ "#ffff33", # Yellow
120
+ "#a65628", # Brown
121
+ "#f781bf", # Pink
122
  ]
123
+
124
+ # Assign colors to investors
125
+ unique_investors = investors
126
+ num_colors = len(unique_investors)
127
+ color_palette_extended = color_palette * (num_colors // len(color_palette) + 1)
128
+ investor_color_map = {investor: color_palette_extended[i] for i, investor in enumerate(unique_investors)}
129
+
130
+ # Create the graph
131
  G = nx.Graph()
132
  for investor in investors:
133
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
134
  for company in companies:
135
+ G.add_node(company, type='company', valuation=filtered_data.loc[filtered_data["Company"] == company, "Valuation_Billions"].values[0])
136
+ G.add_node(investor, type='investor')
137
  G.add_edge(investor, company)
138
+
139
+ # Position nodes using spring layout
140
+ pos = nx.spring_layout(G, seed=42, k=0.5)
141
+
142
+ # Prepare edge traces
143
  edge_x = []
144
  edge_y = []
 
145
  for edge in G.edges():
146
  x0, y0 = pos[edge[0]]
147
  x1, y1 = pos[edge[1]]
148
+ edge_x += [x0, x1, None]
149
+ edge_y += [y0, y1, None]
150
+
151
  edge_trace = go.Scatter(
152
  x=edge_x,
153
  y=edge_y,
154
+ line=dict(width=0.5, color='#888'),
155
  hoverinfo='none',
156
  mode='lines'
157
  )
158
+
159
+ # Prepare node traces
160
  node_x = []
161
  node_y = []
162
  node_text = []
163
  node_color = []
164
  node_size = []
165
+ node_type = []
166
+
167
+ for node in G.nodes(data=True):
168
+ x, y = pos[node[0]]
169
  node_x.append(x)
170
  node_y.append(y)
171
+ node_type.append(node[1]['type'])
172
+ if node[1]['type'] == 'investor':
173
+ node_text.append(node[0]) # Investor labels
174
+ node_color.append(investor_color_map[node[0]])
175
  node_size.append(20) # Fixed size for investors
176
  else:
177
+ valuation = node[1]['valuation']
178
+ size = (valuation * 5) if pd.notnull(valuation) else 10 # Scale size
179
+ size = max(size, 5) # Minimum size
 
 
 
 
 
180
  node_size.append(size)
181
+ node_text.append(f"{node[0]}<br>Valuation: ${valuation}B" if pd.notnull(valuation) else f"{node[0]}<br>Valuation: N/A")
182
+ node_color.append("#b2df8a") # Light green for companies
183
+
184
  node_trace = go.Scatter(
185
  x=node_x,
186
  y=node_y,
187
+ mode='markers+text',
188
  text=node_text,
189
+ textposition="top center",
190
  hoverinfo='text',
191
  marker=dict(
192
  showscale=False,
 
193
  color=node_color,
194
+ size=node_size,
195
+ line=dict(width=1, color='white')
196
  )
197
  )
198
+
199
+ # Create the figure
200
  fig = go.Figure(data=[edge_trace, node_trace])
201
+
202
+ # Add legend manually
203
+ investor_colors = list(investor_color_map.values())[:8] # Limit to first 8 for legend
204
+ investor_names = list(investor_color_map.keys())[:8]
205
+
206
+ for i, investor in enumerate(investor_names):
207
+ fig.add_trace(go.Scatter(
208
+ x=[None],
209
+ y=[None],
210
+ mode='markers',
211
+ marker=dict(
212
+ size=10,
213
+ color=investor_color_map[investor]
214
+ ),
215
+ legendgroup='Investors',
216
+ showlegend=True,
217
+ name=investor
218
+ ))
219
+
220
+ # Update layout for better aesthetics
221
  fig.update_layout(
222
+ title={
223
+ 'text': "Venture Networks",
224
+ 'y':0.95,
225
+ 'x':0.5,
226
+ 'xanchor': 'center',
227
+ 'yanchor': 'top'
228
+ },
229
+ titlefont_size=24,
230
+ showlegend=True,
231
+ legend=dict(
232
+ title="Top Investors",
233
+ itemsizing='constant',
234
+ itemclick='toggleothers',
235
+ itemdoubleclick='toggle',
236
+ font=dict(size=10)
237
+ ),
238
+ margin=dict(l=40, r=40, t=80, b=40),
239
  hovermode='closest',
240
  width=1200,
241
+ height=800,
242
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
243
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
244
  )
245
+
246
  return fig
247
 
248
+ # -------------------- Gradio Application --------------------
249
  def app(selected_country, selected_industry, selected_investors):
250
+ """
251
+ Main application function that filters data and generates the network graph.
252
+ """
253
  investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
254
  graph = generate_graph(investors, filtered_data)
255
+ return ', '.join(investors) if investors else "No investors found.", graph
256
 
 
257
  def main():
258
+ """
259
+ Initializes and launches the Gradio interface.
260
+ """
261
+ # Prepare dropdown options
262
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
263
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
264
  investor_list = sorted(investor_company_mapping.keys())
265
 
266
+ # Define Gradio Blocks
267
+ with gr.Blocks(css="""
268
+ .gradio-container {
269
+ background-color: #f9f9f9;
270
+ padding: 20px;
271
+ }
272
+ .gradio-row {
273
+ justify-content: center;
274
+ }
275
+ """) as demo:
276
+ gr.Markdown("""
277
+ # Venture Networks Visualization
278
+
279
+ Explore the relationships between investors and companies across different countries and industries. Use the filters below to customize the network graph.
280
+
281
+ **Instructions:**
282
+ - Select a country and/or industry to filter the data.
283
+ - Choose one or more investors to focus on specific investment activities.
284
+ - Hover over company nodes to view their valuations.
285
+ """)
286
 
287
+ with gr.Row():
288
+ with gr.Column(scale=1):
289
+ country_filter = gr.Dropdown(
290
+ choices=country_list,
291
+ label="Country",
292
+ value="All",
293
+ info="Select a country to filter the data."
294
+ )
295
+ industry_filter = gr.Dropdown(
296
+ choices=industry_list,
297
+ label="Industry",
298
+ value="All",
299
+ info="Select an industry to filter the data."
300
+ )
301
+ investor_filter = gr.Dropdown(
302
+ choices=investor_list,
303
+ label="Investor",
304
+ value=[],
305
+ multiselect=True,
306
+ info="Select one or more investors to focus on their investments."
307
+ )
308
+ reset_button = gr.Button("Reset Filters", variant="secondary")
309
+ with gr.Column(scale=3):
310
+ graph_output = gr.Plot(label="Network Graph")
311
+
312
+ with gr.Row():
313
+ investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
314
 
315
+ # Define Inputs and Outputs
316
  inputs = [country_filter, industry_filter, investor_filter]
317
  outputs = [investor_output, graph_output]
318
 
319
+ # Define Event Handlers
320
+ country_filter.change(fn=app, inputs=inputs, outputs=outputs)
321
+ industry_filter.change(fn=app, inputs=inputs, outputs=outputs)
322
+ investor_filter.change(fn=app, inputs=inputs, outputs=outputs)
323
+ reset_button.click(fn=lambda: ["", go.Figure()], inputs=None, outputs=outputs)
324
+
325
+ # Add Footer
326
+ gr.Markdown("""
327
+ ---
328
+ © 2024 Venture Networks Visualization Tool
329
+ """)
330
 
331
+ # Launch the Gradio app
332
+ demo.launch()
333
 
334
  if __name__ == "__main__":
335
  main()