LeonceNsh committed on
Commit
5c05d7a
·
verified ·
1 Parent(s): e6abf6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -131
app.py CHANGED
@@ -1,144 +1,268 @@
 
1
  import networkx as nx
2
- import matplotlib.pyplot as plt
3
  import gradio as gr
4
- from io import BytesIO
5
- from PIL import Image
6
-
7
# Investor name -> portfolio companies drawn as that investor's neighbours.
# Insertion order is kept as-is: it is the default selection order used when
# no investors are ticked.
investors = {
    "Accel": [
        "Meta",
        "Dropbox",
        "Spotify",
        "Adroll",
        "PackLink",
        "Zoom",
        "Slack",
    ],
    "Andreessen Horowitz": ["Airbnb", "Lyft", "Pinterest", "Coinbase", "Robinhood", "Slack"],
    "Google Ventures": ["Uber", "LendingClub"],
    "Greylock": ["Workday", "Palo Alto Networks"],
    "Lightspeed Venture Partners": ["Snap", "Grubhub", "AppDynamics"],
    "Benchmark": ["Snap", "Uber", "WeWork"],
    "Norwest Venture Partners": ["LendingClub", "Opendoor"],
    "Emergence Capital Partners": ["Zoom", "Box", "Salesforce"],
    "Trinity Ventures": ["New Relic", "Care.com", "TubeMogul"],
    "Citi Ventures": ["Square", "Nutanix"],
    "Sequoia": ["Alphabet (Google)", "NVIDIA", "Dropbox", "Airbnb", "Meta"],
    "Y Combinator": ["Dropbox", "Airbnb", "Coinbase", "DoorDash", "Reddit"],
}

# Example market capitalization per company, in billions of USD.
# Companies absent from this table fall back to ``default_size``.
market_cap = {
    "Meta": 900, "Dropbox": 10, "Spotify": 30, "Zoom": 20, "Slack": 27,
    "Airbnb": 100, "Lyft": 4, "Pinterest": 14, "Coinbase": 70, "Robinhood": 10,
    "Uber": 60, "LendingClub": 1, "Snap": 18, "Grubhub": 6, "AppDynamics": 1,
    "WeWork": 0.9, "Opendoor": 3, "Box": 4, "Salesforce": 200, "Square": 90,
    "Nutanix": 10, "Alphabet (Google)": 1500, "NVIDIA": 1200,
}

# Size stand-in for companies that have no market_cap entry.
default_size = 5

# Fixed hex colour per investor, used for both its node and its edges.
investor_colors = {
    "Accel": "#1f77b4",
    "Andreessen Horowitz": "#ff7f0e",
    "Google Ventures": "#2ca02c",
    "Greylock": "#d62728",
    "Lightspeed Venture Partners": "#9467bd",
    "Benchmark": "#8c564b",
    "Norwest Venture Partners": "#e377c2",
    "Emergence Capital Partners": "#7f7f7f",
    "Trinity Ventures": "#bcbd22",
    "Citi Ventures": "#17becf",
    "Sequoia": "#1b9e77",
    "Y Combinator": "#d95f02",
}
70
-
71
def generate_graph(selected_investors):
    """Render the investor/company network for the chosen investors.

    Args:
        selected_investors: Investor names (keys of ``investors``). An empty
            or missing selection means "draw every investor".

    Returns:
        A PIL image holding the rendered PNG of the graph.

    Raises:
        KeyError: If a selected name is not a key of ``investors`` or
            ``investor_colors``.
    """
    if not selected_investors:
        selected_investors = list(investors.keys())

    G = nx.Graph()

    # One edge per (investor, company) pair, tinted with the investor's colour.
    for investor in selected_investors:
        for company in investors[investor]:
            G.add_edge(investor, company, color=investor_colors[investor])

    edge_colors = [G[u][v]['color'] for u, v in G.edges]

    # Investors get their own colour and a fixed size; companies are khaki and
    # scaled by market capitalization (or the default when unknown).
    node_colors = []
    node_sizes = []
    for node in G.nodes:
        if node in investor_colors:
            node_colors.append(investor_colors[node])
            node_sizes.append(2000)  # Fixed size for investors
        else:
            node_colors.append("#F0E68C")  # Khaki for companies
            node_sizes.append(market_cap.get(node, default_size) * 100)  # Scale up sizes

    buf = BytesIO()
    fig = plt.figure(figsize=(18, 18))
    try:
        pos = nx.spring_layout(G, k=0.2, seed=42)  # Fixed seed for consistency
        nx.draw(
            G, pos,
            with_labels=True,
            node_size=node_sizes,
            node_color=node_colors,
            font_size=10,
            font_weight="bold",
            edge_color=edge_colors,
            width=2
        )
        plt.title("Venture Funded Companies as a Densely Connected Subgraph", fontsize=20)
        plt.axis('off')

        # Save through the figure object so the PNG always comes from the
        # figure created above, not whatever pyplot considers "current".
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release this figure, even when drawing or saving fails;
        # otherwise it stays registered with pyplot and leaks memory.
        plt.close(fig)

    buf.seek(0)

    # Convert BytesIO to PIL image
    image = Image.open(buf)
    return image
122
 
123
- # Define Gradio interface
124
def main():
    """Assemble the single-function Gradio interface and start serving it."""
    # Alphabetical choices make the checkbox list easier to scan.
    investor_list = sorted(investors)

    investor_picker = gr.CheckboxGroup(
        choices=investor_list,
        label="Select Investors",
        value=investor_list,  # every investor ticked by default
    )
    graph_view = gr.Image(type="pil", label="Venture Network Graph")

    iface = gr.Interface(
        fn=generate_graph,
        inputs=investor_picker,
        outputs=graph_view,
        title="Venture Networks Visualization",
        description="Select investors to visualize their investments in various companies. The graph shows connections between investors and the companies they've invested in. Node sizes represent market capitalization.",
        flagging_mode="never",
    )
    iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  if __name__ == "__main__":
144
  main()
 
1
+ import pandas as pd
2
  import networkx as nx
3
+ import plotly.graph_objects as go
4
  import gradio as gr
5
+ import re
6
+ import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# Configure logging once at import time and take a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Source CSV for the app; the first line of the file is skipped on load.
file_path = "cbinsights_data.csv"  # Replace with your actual file path

try:
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    # Missing file gets its own message; the error still propagates.
    logger.error(f"File not found: {file_path}")
    raise
except Exception as load_error:
    # Anything else (parse errors, permissions, ...) is logged and re-raised.
    logger.error(f"Error loading CSV file: {load_error}")
    raise
else:
    logger.info("CSV file loaded successfully.")
24
 
25
# Standardize column names: trim surrounding whitespace and lower-case them so
# the lookups below do not depend on the header's exact capitalization.
data.columns = data.columns.str.strip().str.lower()
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Identify the valuation column. Names are already lower-cased above, so a
# plain substring test is enough. Exactly one match is required because the
# numeric valuation is derived from it.
valuation_columns = [col for col in data.columns if 'valuation' in col]
if len(valuation_columns) != 1:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")

valuation_column = valuation_columns[0]
logger.info(f"Using valuation column: {valuation_column}")

# Clean and prepare data.
# Strip text cells first so stray whitespace cannot interfere with the numeric
# conversion or with the equality filters applied later.
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)

# Drop currency symbols and thousands separators, then coerce to a number;
# anything that still is not numeric becomes NaN. The raw string keeps "\$" a
# regex escape instead of an invalid Python string escape.
data["Valuation_Billions"] = data[valuation_column].replace({r'\$': '', ',': ''}, regex=True)
data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')

# Give the columns the names the rest of the module uses. Lower-casing leaves
# inner spaces untouched, so both the underscore and the space-separated
# spellings of the multi-word headers are accepted.
data.rename(columns={
    "company": "Company",
    "date_joined": "Date_Joined",
    "date joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors",
    "select investors": "Select_Investors",
}, inplace=True)

# Fail early with a readable message instead of a KeyError deep inside the
# filtering or graph code.
_required_columns = ["Company", "Country", "Industry", "Select_Investors"]
_missing_columns = [col for col in _required_columns if col not in data.columns]
if _missing_columns:
    logger.error(f"Dataset is missing required columns: {_missing_columns}")
    raise ValueError(f"Dataset is missing required columns: {_missing_columns}")

logger.info("Data cleaned and columns renamed.")
52
+
53
+ # Build investor-company mapping
54
def build_investor_company_mapping(df):
    """Map each investor name to the companies that list it.

    Args:
        df: DataFrame with a "Company" column and a "Select_Investors" column
            whose cells hold comma-separated investor names.

    Returns:
        dict mapping investor name -> list of companies, in row order. Names
        are stripped of surrounding whitespace; empty names are dropped.

    Raises:
        KeyError: If either required column is missing from ``df``.
    """
    mapping = {}
    # Read only the two needed columns; iterrows() would build a whole Series
    # per row just to access these two fields.
    for company, investors in zip(df["Company"], df["Select_Investors"]):
        if not isinstance(investors, str):
            # Missing (NaN/None) or non-text cell: there is nothing to split.
            continue
        for investor in investors.split(","):
            investor = investor.strip()
            if investor:
                mapping.setdefault(investor, []).append(company)
    return mapping
65
+
66
# Investor -> companies index over the full, unfiltered dataset; built once
# when the module is imported.
investor_company_mapping = build_investor_company_mapping(data)
logger.info("Investor to company mapping created.")
68
+
69
+ # Filter investors by country, industry, investor selection, and company selection
70
def filter_investors(selected_country, selected_industry, selected_investors, selected_company):
    """Narrow the dataset by the active filters and list the investors left.

    Args:
        selected_country: Country to keep, or "All" for no country filter.
        selected_industry: Industry to keep, or "All" for no industry filter.
        selected_investors: Investor names; a row is kept when it lists at
            least one of them. Empty/None disables this filter.
        selected_company: Company to keep, or "All" for no company filter.

    Returns:
        Tuple ``(filtered_investors, filtered_data)``: every investor that
        appears in the remaining rows, and the remaining rows themselves.
    """
    filtered_data = data.copy()
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
    if selected_investors:
        wanted = set(selected_investors)

        def lists_wanted_investor(cell):
            # Compare whole, stripped names. A substring/regex test would also
            # match any investor whose name merely contains a selected name.
            return isinstance(cell, str) and any(
                name.strip() in wanted for name in cell.split(",")
            )

        # Explicit bool dtype so an empty frame is still treated as a row mask
        # rather than as a (zero-length) list of column labels.
        keep_row = pd.Series(
            [lists_wanted_investor(cell) for cell in filtered_data["Select_Investors"]],
            index=filtered_data.index,
            dtype=bool,
        )
        filtered_data = filtered_data[keep_row]
    if selected_company != "All":
        filtered_data = filtered_data[filtered_data["Company"] == selected_company]

    # Every investor of the surviving companies is reported, not only the
    # selected ones, so co-investors show up in the network.
    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    filtered_investors = list(investor_company_mapping_filtered.keys())
    return filtered_investors, filtered_data
85
+
86
+ # Generate Plotly graph
87
def generate_graph(investors, filtered_data):
    """Build the interactive investor/company network as a Plotly figure.

    Args:
        investors: Investor names to draw as labelled, coloured nodes.
        filtered_data: DataFrame with "Company", "Select_Investors",
            "Industry" and "Valuation_Billions" columns; its rows supply the
            company nodes, the edges and the hover details.

    Returns:
        A ``go.Figure`` with one legend entry per investor, an edge trace and
        a node trace. An empty figure is returned when ``investors`` is empty.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    color_palette = [
        "#377eb8",  # Blue
        "#e41a1c",  # Red
        "#4daf4a",  # Green
        "#984ea3",  # Purple
        "#ff7f00",  # Orange
        "#ffff33",  # Yellow
        "#a65628",  # Brown
        "#f781bf",  # Pink
        "#999999",  # Grey
    ]
    # Wrap around the palette when there are more investors than colours.
    investor_color_map = {
        investor: color_palette[i % len(color_palette)]
        for i, investor in enumerate(investors)
    }

    # Exact-name lookup of each investor's companies. Matching investor names
    # as substrings of the raw cell would link an investor to companies backed
    # by a different investor whose name merely contains it.
    companies_by_investor = build_investor_company_mapping(filtered_data)

    G = nx.Graph()
    for investor in investors:
        for company in companies_by_investor.get(investor, []):
            # Company first, then investor: keeps the node insertion order of
            # the earlier implementation (presumably relevant to the seeded
            # layout below -- TODO confirm).
            G.add_node(company)
            G.add_node(investor)
            G.add_edge(investor, company)

    pos = nx.spring_layout(G, seed=42)

    # Edges go into one trace; a None entry breaks the line between segments.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#aaaaaa'),
        hoverinfo='none',
        mode='lines',
        showlegend=False,  # the legend is reserved for investors
    )

    # Look company details up once instead of scanning the frame per node.
    # setdefault keeps the first row when a company appears more than once.
    company_details = {}
    for company, valuation, industry in zip(
        filtered_data["Company"],
        filtered_data["Valuation_Billions"],
        filtered_data["Industry"],
    ):
        company_details.setdefault(company, (valuation, industry))

    investor_names = set(investors)

    node_x = []
    node_y = []
    node_text = []
    node_color = []
    node_size = []
    node_hovertext = []

    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        if node in investor_names:
            node_text.append(node)
            node_color.append(investor_color_map[node])
            node_size.append(30)
            node_hovertext.append(f"Investor: {node}")
            continue

        valuation, industry = company_details.get(node, (None, None))
        has_valuation = not pd.isnull(valuation)
        # Marker size grows with valuation, never below 10; companies with an
        # unknown valuation get a fixed size.
        node_size.append(max(valuation * 5, 10) if has_valuation else 15)
        node_text.append("")  # companies are identified on hover only
        node_color.append("#a6d854")
        hovertext = f"Company: {node}"
        if not pd.isnull(industry):
            hovertext += f"<br>Industry: {industry}"
        if has_valuation:
            hovertext += f"<br>Valuation: ${valuation}B"
        node_hovertext.append(hovertext)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_hovertext,
        marker=dict(
            showscale=False,
            size=node_size,
            color=node_color,
            line=dict(width=0.5, color='#333333')
        ),
        textposition="middle center",
        textfont=dict(size=12, color="#000000"),
        showlegend=False,  # the legend is reserved for investors
    )

    # Placeholder traces with no points: they exist only to give every
    # investor a coloured legend entry.
    legend_items = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(size=10, color=color),
            legendgroup=investor,
            showlegend=True,
            name=investor
        )
        for investor, color in investor_color_map.items()
    ]

    fig = go.Figure(data=legend_items + [edge_trace, node_trace])
    fig.update_layout(
        # Nested title.font replaces the deprecated ``titlefont_size``.
        title=dict(text="Venture Networks", font=dict(size=24)),
        margin=dict(l=20, r=20, t=60, b=20),
        hovermode='closest',
        width=1200,
        height=800,
        autosize=True,
        xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
        yaxis={'showgrid': False, 'zeroline': False, 'visible': False}
    )

    return fig
223
 
224
+ # Gradio app
225
def app(selected_country, selected_industry, selected_company, selected_investors):
    """Gradio callback: apply the filters and return (investor text, figure).

    Note the argument order differs from ``filter_investors``, which takes the
    investor selection before the company.
    """
    matched_investors, matched_rows = filter_investors(
        selected_country, selected_industry, selected_investors, selected_company
    )
    if matched_investors:
        figure = generate_graph(matched_investors, matched_rows)
        return ', '.join(matched_investors), figure
    # Nothing survived the filters: report it and show a blank plot.
    return "No investors found with the selected filters.", go.Figure()
231
+
232
+ # Main function
233
def main():
    """Build the Blocks UI, wire every filter to ``app`` and launch it."""

    def with_all_option(column):
        # Sorted distinct non-null values, preceded by the "All" sentinel.
        return ["All"] + sorted(data[column].dropna().unique())

    country_list = with_all_option("Country")
    industry_list = with_all_option("Industry")
    company_list = with_all_option("Company")
    investor_list = sorted(investor_company_mapping.keys())

    with gr.Blocks(title="Venture Networks Visualization") as demo:
        gr.Markdown("""
        # Venture Networks Visualization
        Explore the connections between investors and companies in the venture capital ecosystem. Use the filters below to customize the network graph.
        """)
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
            company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
            investor_filter = gr.Dropdown(choices=investor_list, label="Select Investors", value=[], multiselect=True)
        with gr.Row():
            investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
            graph_output = gr.Plot(label="Network Graph")

        inputs = [country_filter, industry_filter, company_filter, investor_filter]
        outputs = [investor_output, graph_output]

        # Changing any one filter re-runs the callback with all four values.
        for control in inputs:
            control.change(app, inputs, outputs)

        gr.Markdown("""
        **Instructions:** Use the dropdowns to filter the network graph.
        """)

    demo.launch()
266
 
267
  if __name__ == "__main__":
268
  main()