LeonceNsh committed on
Commit
a7cb589
·
verified ·
1 Parent(s): 9e2bc99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -199
app.py CHANGED
@@ -2,9 +2,7 @@ import pandas as pd
2
  import networkx as nx
3
  import plotly.graph_objects as go
4
  import gradio as gr
5
- import re
6
  import logging
7
- import os
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO)
@@ -13,48 +11,32 @@ logger = logging.getLogger(__name__)
13
  # Load and preprocess the dataset
14
  file_path = "cbinsights_data.csv" # Replace with your actual file path
15
 
16
- def load_data():
17
- if not os.path.exists(file_path):
18
- logger.error(f"File not found: {file_path}")
19
- raise FileNotFoundError(f"File not found: {file_path}")
20
-
21
- try:
22
- data = pd.read_csv(file_path, skiprows=1)
23
- logger.info("CSV file loaded successfully.")
24
- except Exception as e:
25
- logger.error(f"Error loading CSV file: {e}")
26
- raise
27
-
28
- # Standardize column names
29
- data.columns = data.columns.str.strip().str.lower()
30
- logger.info(f"Standardized Column Names: {data.columns.tolist()}")
31
-
32
- # Identify the valuation column
33
- valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
34
- if len(valuation_columns) != 1:
35
- logger.error("Unable to identify a single valuation column.")
36
- raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
37
-
38
- valuation_column = valuation_columns[0]
39
- logger.info(f"Using valuation column: {valuation_column}")
40
-
41
- # Clean and prepare data
42
- data["Valuation_Billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
43
- data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
44
- data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
45
- data.rename(columns={
46
- "company": "Company",
47
- "date_joined": "Date_Joined",
48
- "country": "Country",
49
- "city": "City",
50
- "industry": "Industry",
51
- "select_investors": "Select_Investors"
52
- }, inplace=True)
53
-
54
- logger.info("Data cleaned and columns renamed.")
55
- return data
56
-
57
- data = load_data()
58
 
59
  # Build investor-company mapping
60
  def build_investor_company_mapping(df):
@@ -72,19 +54,17 @@ def build_investor_company_mapping(df):
72
  investor_company_mapping = build_investor_company_mapping(data)
73
  logger.info("Investor to company mapping created.")
74
 
75
- # Filter investors by country, industry, investor selection, and company selection
76
- def filter_investors(selected_country, selected_industry, selected_investors, selected_company):
77
  filtered_data = data.copy()
78
  if selected_country != "All":
79
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
80
  if selected_industry != "All":
81
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
82
- if selected_company != "All":
83
- filtered_data = filtered_data[filtered_data["Company"] == selected_company]
84
  if selected_investors:
85
- pattern = '|'.join([re.escape(inv) for inv in selected_investors])
86
  filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
87
-
88
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
89
  filtered_investors = list(investor_company_mapping_filtered.keys())
90
  return filtered_investors, filtered_data
@@ -93,190 +73,77 @@ def filter_investors(selected_country, selected_industry, selected_investors, se
93
  def generate_graph(investors, filtered_data):
94
  if not investors:
95
  logger.warning("No investors selected.")
96
- return go.Figure()
97
-
98
- # Create a color map for investors
99
- unique_investors = investors
100
- num_colors = len(unique_investors)
101
- color_palette = [
102
- "#377eb8", # Blue
103
- "#e41a1c", # Red
104
- "#4daf4a", # Green
105
- "#984ea3", # Purple
106
- "#ff7f00", # Orange
107
- "#ffff33", # Yellow
108
- "#a65628", # Brown
109
- "#f781bf", # Pink
110
- "#999999", # Grey
111
- ]
112
- while num_colors > len(color_palette):
113
- color_palette.extend(color_palette)
114
-
115
- investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
116
 
117
  G = nx.Graph()
118
  for investor in investors:
119
- companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist()
120
  for company in companies:
121
- G.add_node(company)
122
- G.add_node(investor)
123
  G.add_edge(investor, company)
124
 
125
  pos = nx.spring_layout(G, seed=42)
126
- edge_x = []
127
- edge_y = []
128
-
129
- for edge in G.edges():
130
- x0, y0 = pos[edge[0]]
131
- x1, y1 = pos[edge[1]]
132
- edge_x.extend([x0, x1, None])
133
- edge_y.extend([y0, y1, None])
134
-
135
  edge_trace = go.Scatter(
136
- x=edge_x,
137
- y=edge_y,
138
- line=dict(width=0.5, color='#aaaaaa'),
139
  hoverinfo='none',
140
  mode='lines'
141
  )
142
 
143
- node_x = []
144
- node_y = []
145
- node_text = []
146
- node_color = []
147
- node_size = []
148
- node_hovertext = []
149
-
150
- for node in G.nodes():
151
- x, y = pos[node]
152
- node_x.append(x)
153
- node_y.append(y)
154
- if node in investors:
155
- node_text.append(node)
156
- node_color.append(investor_color_map[node])
157
- node_size.append(30)
158
- node_hovertext.append(f"Investor: {node}")
159
- else:
160
- valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
161
- industry = filtered_data.loc[filtered_data["Company"] == node, "Industry"].values
162
- if len(valuation) > 0 and not pd.isnull(valuation[0]):
163
- size = valuation[0] * 5
164
- if size < 10:
165
- size = 10
166
- else:
167
- size = 15
168
- node_size.append(size)
169
- node_text.append("")
170
- node_color.append("#a6d854")
171
- hovertext = f"Company: {node}"
172
- if len(industry) > 0 and not pd.isnull(industry[0]):
173
- hovertext += f"<br>Industry: {industry[0]}"
174
- if len(valuation) > 0 and not pd.isnull(valuation[0]):
175
- hovertext += f"<br>Valuation: ${valuation[0]}B"
176
- node_hovertext.append(hovertext)
177
-
178
  node_trace = go.Scatter(
179
- x=node_x,
180
- y=node_y,
181
- text=node_text,
182
  mode='markers+text',
183
  hoverinfo='text',
184
- hovertext=node_hovertext,
185
  marker=dict(
186
  showscale=False,
187
- size=node_size,
188
- color=node_color,
189
- line=dict(width=0.5, color='#333333')
190
- ),
191
- textposition="middle center",
192
- textfont=dict(size=12, color="#000000")
193
- )
194
-
195
- legend_items = []
196
- for investor in unique_investors:
197
- legend_items.append(
198
- go.Scatter(
199
- x=[None],
200
- y=[None],
201
- mode='markers',
202
- marker=dict(
203
- size=10,
204
- color=investor_color_map[investor]
205
- ),
206
- legendgroup=investor,
207
- showlegend=True,
208
- name=investor
209
- )
210
  )
 
211
 
212
- fig = go.Figure(data=legend_items + [edge_trace, node_trace])
213
  fig.update_layout(
214
  title="Venture Networks",
215
- titlefont_size=24,
216
- margin=dict(l=20, r=20, t=60, b=20),
 
217
  hovermode='closest',
218
- width=1200,
219
- height=800
220
- )
221
-
222
- fig.update_layout(
223
- autosize=True,
224
- xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
225
- yaxis={'showgrid': False, 'zeroline': False, 'visible': False}
226
  )
 
227
 
228
- return fig
229
-
230
- # Gradio app
231
- def app(selected_country, selected_industry, selected_company, selected_investors):
232
- investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors, selected_company)
233
- if not investors:
234
- return "No investors found with the selected filters.", go.Figure()
235
- graph = generate_graph(investors, filtered_data)
236
- return ', '.join(investors), graph
237
 
238
  # Main function
239
  def main():
240
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
241
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
242
- company_list = ["All"] + sorted(data["Company"].dropna().unique())
243
  investor_list = sorted(investor_company_mapping.keys())
244
 
245
- with gr.Blocks(title="Venture Networks Visualization") as demo:
246
- gr.Markdown("""
247
- # Venture Networks Visualization
248
- Explore the connections between investors and companies in the venture capital ecosystem. Use the filters below to customize the network graph.
249
- """)
250
- with gr.Row():
251
- country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
252
- industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
253
- company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
254
- investor_filter = gr.Dropdown(choices=investor_list, label="Select Investors", value=[], multiselect=True)
255
- with gr.Row():
256
- investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
257
- graph_output = gr.Plot(label="Network Graph")
258
 
259
- inputs = [country_filter, industry_filter, company_filter, investor_filter]
260
- outputs = [investor_output, graph_output]
 
 
 
 
261
 
262
- # Update the graph when any filter changes
263
- country_filter.change(app, inputs, outputs)
264
- industry_filter.change(app, inputs, outputs)
265
- company_filter.change(app, inputs, outputs)
266
- investor_filter.change(app, inputs, outputs)
267
 
268
- gr.Markdown("""
269
- **Instructions:**
270
- - **Country**: Filter companies by country.
271
- - **Industry**: Filter companies by industry.
272
- - **Company**: Select a specific company to focus on.
273
- - **Select Investors**: Choose investors to visualize their network connections.
274
 
275
- **Tips:**
276
- - Hover over nodes to see more information.
277
- - Use the legend to identify investor nodes.
278
- - Adjust filters to refine your network view.
279
- """)
280
 
281
  demo.launch()
282
 
 
2
  import networkx as nx
3
  import plotly.graph_objects as go
4
  import gradio as gr
 
5
  import logging
 
6
 
7
  # Set up logging
8
  logging.basicConfig(level=logging.INFO)
 
11
  # Load and preprocess the dataset
12
  file_path = "cbinsights_data.csv" # Replace with your actual file path
13
 
14
+ try:
15
+ data = pd.read_csv(file_path, skiprows=1)
16
+ logger.info("CSV file loaded successfully.")
17
+ except FileNotFoundError:
18
+ logger.error(f"File not found: {file_path}")
19
+ raise
20
+ except Exception as e:
21
+ logger.error(f"Error loading CSV file: {e}")
22
+ raise
23
+
24
+ # Standardize column names
25
+ data.columns = data.columns.str.strip().str.lower()
26
+ logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
+
28
+ # Clean and prepare data
29
+ data["valuation_billions"] = data["valuation (usd billions)"].replace({'\$': '', ',': ''}, regex=True)
30
+ data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
31
+ data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
32
+ data.rename(columns={
33
+ "company": "Company",
34
+ "country": "Country",
35
+ "industry": "Industry",
36
+ "select_investors": "Select_Investors"
37
+ }, inplace=True)
38
+
39
+ logger.info("Data cleaned and columns renamed.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # Build investor-company mapping
42
  def build_investor_company_mapping(df):
 
54
  investor_company_mapping = build_investor_company_mapping(data)
55
  logger.info("Investor to company mapping created.")
56
 
57
+ # Filter investors by country, industry, and investor selection
58
+ def filter_investors(selected_country, selected_industry, selected_investors):
59
  filtered_data = data.copy()
60
  if selected_country != "All":
61
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
62
  if selected_industry != "All":
63
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
 
 
64
  if selected_investors:
65
+ pattern = '|'.join(selected_investors)
66
  filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
67
+
68
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
69
  filtered_investors = list(investor_company_mapping_filtered.keys())
70
  return filtered_investors, filtered_data
 
73
  def generate_graph(investors, filtered_data):
74
  if not investors:
75
  logger.warning("No investors selected.")
76
+ return go.Figure(), "No data available for the selected filters."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  G = nx.Graph()
79
  for investor in investors:
80
+ companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
81
  for company in companies:
 
 
82
  G.add_edge(investor, company)
83
 
84
  pos = nx.spring_layout(G, seed=42)
 
 
 
 
 
 
 
 
 
85
  edge_trace = go.Scatter(
86
+ x=[pos[node][0] for edge in G.edges() for node in edge] + [None],
87
+ y=[pos[node][1] for edge in G.edges() for node in edge] + [None],
88
+ line=dict(width=0.5, color='#888'),
89
  hoverinfo='none',
90
  mode='lines'
91
  )
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  node_trace = go.Scatter(
94
+ x=[pos[node][0] for node in G.nodes()],
95
+ y=[pos[node][1] for node in G.nodes()],
96
+ text=[node for node in G.nodes()],
97
  mode='markers+text',
98
  hoverinfo='text',
 
99
  marker=dict(
100
  showscale=False,
101
+ size=[20 if node in investors else 10 for node in G.nodes()],
102
+ color=[('blue' if node in investors else 'lightgreen') for node in G.nodes()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
+ )
105
 
106
+ fig = go.Figure(data=[edge_trace, node_trace])
107
  fig.update_layout(
108
  title="Venture Networks",
109
+ titlefont_size=16,
110
+ showlegend=False,
111
+ margin=dict(b=0,l=0,r=0,t=40),
112
  hovermode='closest',
113
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
114
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
 
 
 
 
 
 
115
  )
116
+ return fig, ""
117
 
118
+ # Gradio interface
119
+ def app(selected_country, selected_industry, selected_investors):
120
+ investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
121
+ graph, message = generate_graph(investors, filtered_data)
122
+ return message, graph
 
 
 
 
123
 
124
  # Main function
125
  def main():
126
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
127
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
 
128
  investor_list = sorted(investor_company_mapping.keys())
129
 
130
+ demo = gr.Blocks()
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ with demo:
133
+ gr.Markdown("## Venture Network Visualization Tool")
134
+ with gr.Row():
135
+ country_filter = gr.Dropdown(label="Select Country", choices=country_list, value="All")
136
+ industry_filter = gr.Dropdown(label="Select Industry", choices=industry_list, value="All")
137
+ investor_filter = gr.Dropdown(label="Select Investors", choices=investor_list, multiselect=True, value=[])
138
 
139
+ message = gr.Textbox(label="Status Message", visible=False)
140
+ graph_output = gr.Plot()
 
 
 
141
 
142
+ inputs = [country_filter, industry_filter, investor_filter]
143
+ outputs = [message, graph_output]
 
 
 
 
144
 
145
+ for input_widget in inputs:
146
+ input_widget.change(fn=app, inputs=inputs, outputs=outputs)
 
 
 
147
 
148
  demo.launch()
149