LeonceNsh committed on
Commit
97568f5
·
verified ·
1 Parent(s): a7cb589

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -50
app.py CHANGED
@@ -2,6 +2,7 @@ import pandas as pd
2
  import networkx as nx
3
  import plotly.graph_objects as go
4
  import gradio as gr
 
5
  import logging
6
 
7
  # Set up logging
@@ -25,13 +26,24 @@ except Exception as e:
25
  data.columns = data.columns.str.strip().str.lower()
26
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
 
 
 
 
 
 
 
 
 
 
# Clean and prepare data
# Strip "$" and "," from the valuation text, then coerce it to numbers; values
# that still fail to parse become NaN (errors='coerce').
# The "$" pattern is a raw string: '\$' in a plain literal is an invalid escape
# sequence that newer Python versions warn about.
data["valuation_billions"] = data["valuation (usd billions)"].replace({r'\$': '', ',': ''}, regex=True)
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
# Trim surrounding whitespace in every object-dtype column.
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
# Give the columns used by the rest of the app their capitalized names.
data.rename(columns={
    "company": "Company",
    "country": "Country",
    "industry": "Industry",
    "select_investors": "Select_Investors"
}, inplace=True)
@@ -54,17 +66,19 @@ def build_investor_company_mapping(df):
54
  investor_company_mapping = build_investor_company_mapping(data)
55
  logger.info("Investor to company mapping created.")
56
 
# Filter investors by country, industry, and investor selection
def filter_investors(selected_country, selected_industry, selected_investors):
    """Narrow the module-level ``data`` frame by the chosen filters.

    Args:
        selected_country: Value compared against the "Country" column;
            "All" disables the filter.
        selected_industry: Value compared against the "Industry" column;
            "All" disables the filter.
        selected_investors: Investor names; a row is kept when its
            "Select_Investors" text contains at least one of them. An empty
            selection disables the filter.

    Returns:
        A ``(filtered_investors, filtered_data)`` tuple: the keys of the
        mapping that ``build_investor_company_mapping`` builds from the
        remaining rows, and the remaining rows themselves.
    """
    # Local import so this function does not depend on a module-level `re`.
    import re

    filtered_data = data.copy()
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
    if selected_investors:
        # Escape each name so characters such as "(", "+" or "." are matched
        # literally instead of being interpreted as regex syntax.
        pattern = '|'.join(re.escape(name) for name in selected_investors)
        filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    filtered_investors = list(investor_company_mapping_filtered.keys())
    return filtered_investors, filtered_data
@@ -73,78 +87,181 @@ def filter_investors(selected_country, selected_industry, selected_investors):
def generate_graph(investors, filtered_data):
    """Build the investor/company network figure.

    Args:
        investors: Investor names to draw. Each one is linked to every company
            whose "Select_Investors" text contains the name.
        filtered_data: DataFrame providing the "Select_Investors" and
            "Company" columns.

    Returns:
        A ``(figure, message)`` tuple. ``message`` is empty on success and
        explains the empty figure when ``investors`` is empty.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure(), "No data available for the selected filters."

    G = nx.Graph()
    for investor in investors:
        # regex=False: the name is matched as literal text, so regex
        # metacharacters inside a name cannot break or widen the match.
        linked = filtered_data["Select_Investors"].str.contains(investor, regex=False, na=False)
        for company in filtered_data.loc[linked, "Company"].tolist():
            G.add_edge(investor, company)

    pos = nx.spring_layout(G, seed=42)

    # One (start, end, None) triple per edge: the None breaks the line so
    # consecutive edges are not joined into a single polyline.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    nodes = list(G.nodes())
    investor_names = set(investors)  # constant-time membership test per node
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        text=nodes,
        mode='markers+text',
        hoverinfo='text',
        marker=dict(
            showscale=False,
            size=[20 if node in investor_names else 10 for node in nodes],
            color=['blue' if node in investor_names else 'lightgreen' for node in nodes]
        )
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        # `titlefont` is a deprecated alias; set the size via title.font.
        title=dict(text="Venture Networks", font=dict(size=16)),
        showlegend=False,
        margin=dict(b=0, l=0, r=0, t=40),
        hovermode='closest',
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
    )
    return fig, ""
 
117
 
# Gradio interface
def app(selected_country, selected_industry, selected_investors):
    """Gradio callback: filter the data, then draw the network.

    Returns:
        ``(message, graph)`` - the status text and figure produced by
        ``generate_graph``, in the order the output components expect.
    """
    matching_investors, matching_rows = filter_investors(
        selected_country, selected_industry, selected_investors
    )
    figure, status = generate_graph(matching_investors, matching_rows)
    return status, figure
 
 
123
 
# Main function
def main():
    """Build the Gradio Blocks UI and launch it."""
    country_list = ["All", *sorted(data["Country"].dropna().unique())]
    industry_list = ["All", *sorted(data["Industry"].dropna().unique())]
    investor_list = sorted(investor_company_mapping.keys())

    # NOTE(review): nesting below is reconstructed from syntax - confirm that
    # only the three dropdowns belong inside the Row.
    with gr.Blocks() as demo:
        gr.Markdown("## Venture Network Visualization Tool")
        with gr.Row():
            country_filter = gr.Dropdown(label="Select Country", choices=country_list, value="All")
            industry_filter = gr.Dropdown(label="Select Industry", choices=industry_list, value="All")
            investor_filter = gr.Dropdown(label="Select Investors", choices=investor_list, multiselect=True, value=[])

        message = gr.Textbox(label="Status Message", visible=False)
        graph_output = gr.Plot()

        inputs = [country_filter, industry_filter, investor_filter]
        outputs = [message, graph_output]

        # Any filter change re-runs the callback with all current values.
        for control in inputs:
            control.change(fn=app, inputs=inputs, outputs=outputs)

    demo.launch()
149
 
150
  if __name__ == "__main__":
 
2
  import networkx as nx
3
  import plotly.graph_objects as go
4
  import gradio as gr
5
+ import re
6
  import logging
7
 
8
  # Set up logging
 
26
  data.columns = data.columns.str.strip().str.lower()
27
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
 
# Identify the valuation column: exactly one column name must contain
# "valuation", otherwise loading is aborted.
valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
if len(valuation_columns) != 1:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")

valuation_column = valuation_columns[0]
logger.info(f"Using valuation column: {valuation_column}")

# Clean and prepare data
# Strip "$" and "," from the valuation text, then coerce it to numbers; values
# that still fail to parse become NaN (errors='coerce').
# The "$" pattern is a raw string: '\$' in a plain literal is an invalid escape
# sequence that newer Python versions warn about.
data["Valuation_Billions"] = data[valuation_column].replace({r'\$': '', ',': ''}, regex=True)
data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
# Trim surrounding whitespace in every object-dtype column.
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
# Give the columns used by the rest of the app their capitalized names.
data.rename(columns={
    "company": "Company",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors"
}, inplace=True)
 
66
  investor_company_mapping = build_investor_company_mapping(data)
67
  logger.info("Investor to company mapping created.")
68
 
# Filter investors by country, industry, investor selection, and company selection
def filter_investors(selected_country, selected_industry, selected_investors, selected_company):
    """Narrow the module-level ``data`` frame by the chosen filters.

    "All" disables the country, industry and company filters; an empty
    investor selection disables the investor filter.

    Returns:
        A ``(filtered_investors, filtered_data)`` tuple: the keys of the
        mapping that ``build_investor_company_mapping`` builds from the
        remaining rows, and the remaining rows themselves.
    """
    rows = data.copy()

    if selected_country != "All":
        rows = rows.loc[rows["Country"] == selected_country]

    if selected_industry != "All":
        rows = rows.loc[rows["Industry"] == selected_industry]

    if selected_investors:
        # Names are escaped so they are matched literally inside the
        # alternation.
        investor_regex = '|'.join(map(re.escape, selected_investors))
        has_selected_investor = rows["Select_Investors"].str.contains(investor_regex, na=False)
        rows = rows.loc[has_selected_investor]

    if selected_company != "All":
        rows = rows.loc[rows["Company"] == selected_company]

    remaining_mapping = build_investor_company_mapping(rows)
    return list(remaining_mapping.keys()), rows
 
def generate_graph(investors, filtered_data):
    """Build the investor/company network figure.

    Investors are drawn as large labelled nodes, one palette color each, with
    a legend entry per investor. Companies are drawn as unlabelled green nodes
    sized by valuation, with industry and valuation in the hover text.

    Args:
        investors: Investor names to draw. Each one is linked to every company
            whose "Select_Investors" text contains the name.
        filtered_data: DataFrame providing the "Select_Investors", "Company",
            "Valuation_Billions" and "Industry" columns.

    Returns:
        The Plotly figure; an empty figure when ``investors`` is empty.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # Create a color map for investors. Indexing modulo the palette length
    # reuses colors in order when there are more investors than colors.
    color_palette = [
        "#377eb8",  # Blue
        "#e41a1c",  # Red
        "#4daf4a",  # Green
        "#984ea3",  # Purple
        "#ff7f00",  # Orange
        "#ffff33",  # Yellow
        "#a65628",  # Brown
        "#f781bf",  # Pink
        "#999999",  # Grey
    ]
    investor_color_map = {
        investor: color_palette[i % len(color_palette)]
        for i, investor in enumerate(investors)
    }

    G = nx.Graph()
    for investor in investors:
        # regex=False matches the name as literal text, which is what
        # escaping the name and matching it as a regex achieved.
        linked = filtered_data["Select_Investors"].str.contains(investor, regex=False, na=False)
        for company in filtered_data.loc[linked, "Company"].tolist():
            # Company first, then investor: keeps node insertion order, and
            # therefore the seeded layout, unchanged.
            G.add_node(company)
            G.add_node(investor)
            G.add_edge(investor, company)

    pos = nx.spring_layout(G, seed=42)

    # One (start, end, None) triple per edge; the None breaks the line so each
    # edge is drawn as its own segment.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#aaaaaa'),
        hoverinfo='none',
        mode='lines'
    )

    # Look up company attributes once instead of scanning the frame for every
    # node; the first row wins when a company appears more than once.
    company_rows = filtered_data.drop_duplicates(subset="Company").set_index("Company")
    valuation_by_company = company_rows["Valuation_Billions"].to_dict()
    industry_by_company = company_rows["Industry"].to_dict()
    investor_names = set(investors)  # constant-time membership test per node

    node_x = []
    node_y = []
    node_text = []
    node_color = []
    node_size = []
    node_hovertext = []

    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

        if node in investor_names:
            node_text.append(node)
            node_color.append(investor_color_map[node])
            node_size.append(30)
            node_hovertext.append(f"Investor: {node}")
            continue

        valuation = valuation_by_company.get(node)
        industry = industry_by_company.get(node)
        has_valuation = not pd.isnull(valuation)

        # Marker size scales with valuation, never below 10; companies without
        # a usable valuation get a fixed size of 15.
        node_size.append(max(valuation * 5, 10) if has_valuation else 15)
        node_text.append("")
        node_color.append("#a6d854")

        hovertext = f"Company: {node}"
        if not pd.isnull(industry):
            hovertext += f"<br>Industry: {industry}"
        if has_valuation:
            hovertext += f"<br>Valuation: ${valuation}B"
        node_hovertext.append(hovertext)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_hovertext,
        marker=dict(
            showscale=False,
            size=node_size,
            color=node_color,
            line=dict(width=0.5, color='#333333')
        ),
        textposition="middle center",
        textfont=dict(size=12, color="#000000")
    )

    # Coordinate-less traces that exist only to give each investor a legend
    # entry in its color.
    legend_items = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(size=10, color=investor_color_map[investor]),
            legendgroup=investor,
            showlegend=True,
            name=investor
        )
        for investor in investors
    ]

    fig = go.Figure(data=legend_items + [edge_trace, node_trace])
    fig.update_layout(
        # `titlefont` is a deprecated alias; set the size via title.font.
        title=dict(text="Venture Networks", font=dict(size=24)),
        margin=dict(l=20, r=20, t=60, b=20),
        hovermode='closest',
        width=1200,
        height=800,
        autosize=True,
        xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
        yaxis={'showgrid': False, 'zeroline': False, 'visible': False}
    )

    return fig
223
 
# Gradio app
def app(selected_country, selected_industry, selected_company, selected_investors):
    """Gradio callback: filter the data, then draw the network.

    Returns:
        ``(text, figure)`` - the comma-separated investor names and their
        network figure, or an explanatory message and an empty figure when no
        investor matches the filters.
    """
    # filter_investors takes the investor selection before the company one.
    matching_investors, matching_rows = filter_investors(
        selected_country, selected_industry, selected_investors, selected_company
    )
    if matching_investors:
        figure = generate_graph(matching_investors, matching_rows)
        return ', '.join(matching_investors), figure
    return "No investors found with the selected filters.", go.Figure()
231
 
# Main function
def main():
    """Build the Gradio Blocks UI and launch it."""
    country_list = ["All", *sorted(data["Country"].dropna().unique())]
    industry_list = ["All", *sorted(data["Industry"].dropna().unique())]
    company_list = ["All", *sorted(data["Company"].dropna().unique())]
    investor_list = sorted(investor_company_mapping.keys())

    # NOTE(review): nesting and the whitespace inside the Markdown strings are
    # reconstructed from syntax - confirm the placement of the plot relative
    # to the second Row.
    with gr.Blocks(title="Venture Networks Visualization") as demo:
        gr.Markdown("""
        # Venture Networks Visualization
        Explore the connections between investors and companies in the venture capital ecosystem. Use the filters below to customize the network graph.
        """)
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
            company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
            investor_filter = gr.Dropdown(choices=investor_list, label="Select Investors", value=[], multiselect=True)
        with gr.Row():
            investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
            graph_output = gr.Plot(label="Network Graph")

        inputs = [country_filter, industry_filter, company_filter, investor_filter]
        outputs = [investor_output, graph_output]

        # Any filter change re-runs the callback with all current values.
        for control in inputs:
            control.change(app, inputs, outputs)

        gr.Markdown("""
        **Instructions:** Use the dropdowns to filter the network graph.
        """)

    demo.launch()
266
 
267
  if __name__ == "__main__":