LeonceNsh committed on
Commit
e6abf6e
·
verified ·
1 Parent(s): 0bbd0c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -253
app.py CHANGED
@@ -1,266 +1,144 @@
1
- import pandas as pd
2
  import networkx as nx
3
- import plotly.graph_objects as go
4
  import gradio as gr
5
- import re
6
- import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Set up logging
9
- logging.basicConfig(level=logging.INFO)
10
- logger = logging.getLogger(__name__)
11
-
12
- # Load and preprocess the dataset
13
- file_path = "cbinsights_data.csv" # Replace with your actual file path
14
-
15
- try:
16
- data = pd.read_csv(file_path, skiprows=1)
17
- logger.info("CSV file loaded successfully.")
18
- except FileNotFoundError:
19
- logger.error(f"File not found: {file_path}")
20
- raise
21
- except Exception as e:
22
- logger.error(f"Error loading CSV file: {e}")
23
- raise
24
-
25
- # Standardize column names
26
- data.columns = data.columns.str.strip().str.lower()
27
- logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
-
29
- # Identify the valuation column
30
- valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
31
- if len(valuation_columns) != 1:
32
- logger.error("Unable to identify a single valuation column.")
33
- raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
34
-
35
- valuation_column = valuation_columns[0]
36
- logger.info(f"Using valuation column: {valuation_column}")
37
-
38
- # Clean and prepare data
39
- data["Valuation_Billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
40
- data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
41
- data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
42
- data.rename(columns={
43
- "company": "Company",
44
- "date_joined": "Date_Joined",
45
- "country": "Country",
46
- "city": "City",
47
- "industry": "Industry",
48
- "select_investors": "Select_Investors"
49
- }, inplace=True)
50
-
51
- logger.info("Data cleaned and columns renamed.")
52
-
53
- # Build investor-company mapping
54
- def build_investor_company_mapping(df):
55
- mapping = {}
56
- for _, row in df.iterrows():
57
- company = row["Company"]
58
- investors = row["Select_Investors"]
59
- if pd.notnull(investors):
60
- for investor in investors.split(","):
61
- investor = investor.strip()
62
- if investor:
63
- mapping.setdefault(investor, []).append(company)
64
- return mapping
65
-
66
- investor_company_mapping = build_investor_company_mapping(data)
67
- logger.info("Investor to company mapping created.")
68
-
69
- # Filter investors by country, industry, investor selection, and company selection
70
- def filter_investors(selected_country, selected_industry, selected_investors, selected_company):
71
- filtered_data = data.copy()
72
- if selected_country != "All":
73
- filtered_data = filtered_data[filtered_data["Country"] == selected_country]
74
- if selected_industry != "All":
75
- filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
76
- if selected_investors:
77
- pattern = '|'.join([re.escape(inv) for inv in selected_investors])
78
- filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
79
- if selected_company != "All":
80
- filtered_data = filtered_data[filtered_data["Company"] == selected_company]
81
-
82
- investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
83
- filtered_investors = list(investor_company_mapping_filtered.keys())
84
- return filtered_investors, filtered_data
85
-
86
- # Generate Plotly graph
87
- def generate_graph(investors, filtered_data):
88
- if not investors:
89
- logger.warning("No investors selected.")
90
- return go.Figure()
91
-
92
- # Create a color map for investors
93
- unique_investors = investors
94
- num_colors = len(unique_investors)
95
- color_palette = [
96
- "#377eb8", # Blue
97
- "#e41a1c", # Red
98
- "#4daf4a", # Green
99
- "#984ea3", # Purple
100
- "#ff7f00", # Orange
101
- "#ffff33", # Yellow
102
- "#a65628", # Brown
103
- "#f781bf", # Pink
104
- "#999999", # Grey
105
- ]
106
- while num_colors > len(color_palette):
107
- color_palette.extend(color_palette)
108
-
109
- investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
110
-
111
  G = nx.Graph()
112
- for investor in investors:
113
- companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist()
 
 
114
  for company in companies:
115
- G.add_node(company)
116
- G.add_node(investor)
117
- G.add_edge(investor, company)
118
-
119
- pos = nx.spring_layout(G, seed=142)
120
- edge_x = []
121
- edge_y = []
122
-
123
- for edge in G.edges():
124
- x0, y0 = pos[edge[0]]
125
- x1, y1 = pos[edge[1]]
126
- edge_x.extend([x0, x1, None])
127
- edge_y.extend([y0, y1, None])
128
-
129
- edge_trace = go.Scatter(
130
- x=edge_x,
131
- y=edge_y,
132
- line=dict(width=0.5, color='#aaaaaa'),
133
- hoverinfo='none',
134
- mode='lines'
135
- )
136
-
137
- node_x = []
138
- node_y = []
139
- node_text = []
140
- node_color = []
141
- node_size = []
142
- node_hovertext = []
143
-
144
- for node in G.nodes():
145
- x, y = pos[node]
146
- node_x.append(x)
147
- node_y.append(y)
148
- if node in investors:
149
- node_text.append(node)
150
- node_color.append(investor_color_map[node])
151
- node_size.append(30)
152
- node_hovertext.append(f"Investor: {node}")
153
  else:
154
- valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
155
- industry = filtered_data.loc[filtered_data["Company"] == node, "Industry"].values
156
- if len(valuation) > 0 and not pd.isnull(valuation[0]):
157
- size = valuation[0] * 5
158
- if size < 10:
159
- size = 10
160
- else:
161
- size = 15
162
- node_size.append(size)
163
- node_text.append("")
164
- node_color.append("#a6d854")
165
- hovertext = f"Company: {node}"
166
- if len(industry) > 0 and not pd.isnull(industry[0]):
167
- hovertext += f"<br>Industry: {industry[0]}"
168
- if len(valuation) > 0 and not pd.isnull(valuation[0]):
169
- hovertext += f"<br>Valuation: ${valuation[0]}B"
170
- node_hovertext.append(hovertext)
171
-
172
- node_trace = go.Scatter(
173
- x=node_x,
174
- y=node_y,
175
- text=node_text,
176
- mode='markers+text',
177
- hoverinfo='text',
178
- hovertext=node_hovertext,
179
- marker=dict(
180
- showscale=False,
181
- size=node_size,
182
- color=node_color,
183
- line=dict(width=0.5, color='#333333')
184
- ),
185
- textposition="middle center",
186
- textfont=dict(size=15, color="#000000")
187
- )
188
-
189
- legend_items = []
190
- for investor in unique_investors:
191
- legend_items.append(
192
- go.Scatter(
193
- x=[None],
194
- y=[None],
195
- mode='markers',
196
- marker=dict(
197
- size=14,
198
- color=investor_color_map[investor]
199
- ),
200
- legendgroup=investor,
201
- showlegend=False,
202
- name=investor
203
- )
204
- )
205
-
206
- fig = go.Figure(data=legend_items + [edge_trace, node_trace])
207
- fig.update_layout(
208
- title="Venture Capital Networks in September 2024",
209
- titlefont_size=24,
210
- margin=dict(l=20, r=20, t=20, b=20),
211
- hovermode='closest',
212
- width=1400,
213
- height=1000
214
  )
215
-
216
- fig.update_layout(
217
- autosize=True,
218
- xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
219
- yaxis={'showgrid': False, 'zeroline': False, 'visible': False}
220
- )
221
-
222
- return fig
223
 
224
- # Gradio app
225
- def app(selected_country, selected_industry, selected_company, selected_investors):
226
- investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors, selected_company)
227
- if not investors:
228
- return "No investors found with the selected filters.", go.Figure()
229
- graph = generate_graph(investors, filtered_data)
230
- return ', '.join(investors), graph
231
 
232
- # Main function
233
  def main():
234
- country_list = ["All"] + sorted(data["Country"].dropna().unique())
235
- industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
236
- company_list = ["All"] + sorted(data["Company"].dropna().unique())
237
- investor_list = sorted(investor_company_mapping.keys())
238
-
239
- with gr.Blocks(title="Venture Networks Visualization") as demo:
240
- gr.Markdown("""
241
- # Venture Networks Visualization
242
- Explore the connections between investors and companies in the venture capital ecosystem. Use the filters below to customize the network graph.
243
- """)
244
- with gr.Row():
245
- country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
246
- industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
247
- company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
248
- investor_filter = gr.Dropdown(choices=investor_list, label="Select Investors", value=[], multiselect=True)
249
- graph_output = gr.Plot(label="Network Graph")
250
-
251
- inputs = [country_filter, industry_filter, company_filter, investor_filter]
252
- outputs = [investor_output, graph_output]
253
-
254
- country_filter.change(app, inputs, outputs)
255
- industry_filter.change(app, inputs, outputs)
256
- company_filter.change(app, inputs, outputs)
257
- investor_filter.change(app, inputs, outputs)
258
-
259
- gr.Markdown("""
260
- **Instructions:** Use the dropdowns to filter the network graph.
261
- """)
262
-
263
- demo.launch()
264
 
265
  if __name__ == "__main__":
266
  main()
 
 
1
  import networkx as nx
2
+ import matplotlib.pyplot as plt
3
  import gradio as gr
4
+ from io import BytesIO
5
+ from PIL import Image
6
+
7
+ # Define investors and their companies
8
+ investors = {
9
+ "Accel": ["Meta", "Dropbox", "Spotify", "Adroll", "PackLink", "Zoom", "Slack"],
10
+ "Andreessen Horowitz": [
11
+ "Airbnb", "Lyft", "Pinterest", "Coinbase", "Robinhood", "Slack"
12
+ ],
13
+ "Google Ventures": ["Uber", "LendingClub"],
14
+ "Greylock": ["Workday", "Palo Alto Networks"],
15
+ "Lightspeed Venture Partners": ["Snap", "Grubhub", "AppDynamics"],
16
+ "Benchmark": ["Snap", "Uber", "WeWork"],
17
+ "Norwest Venture Partners": ["LendingClub", "Opendoor"],
18
+ "Emergence Capital Partners": ["Zoom", "Box", "Salesforce"],
19
+ "Trinity Ventures": ["New Relic", "Care.com", "TubeMogul"],
20
+ "Citi Ventures": ["Square", "Nutanix"],
21
+ "Sequoia": ["Alphabet (Google)", "NVIDIA", "Dropbox", "Airbnb", "Meta"],
22
+ "Y Combinator": ["Dropbox", "Airbnb", "Coinbase", "DoorDash", "Reddit"]
23
+ }
24
+
25
+ # Example market capitalization values (in billions USD)
26
+ market_cap = {
27
+ "Meta": 900,
28
+ "Dropbox": 10,
29
+ "Spotify": 30,
30
+ "Zoom": 20,
31
+ "Slack": 27,
32
+ "Airbnb": 100,
33
+ "Lyft": 4,
34
+ "Pinterest": 14,
35
+ "Coinbase": 70,
36
+ "Robinhood": 10,
37
+ "Uber": 60,
38
+ "LendingClub": 1,
39
+ "Snap": 18,
40
+ "Grubhub": 6,
41
+ "AppDynamics": 1,
42
+ "WeWork": 0.9,
43
+ "Opendoor": 3,
44
+ "Box": 4,
45
+ "Salesforce": 200,
46
+ "Square": 90,
47
+ "Nutanix": 10,
48
+ "Alphabet (Google)": 1500,
49
+ "NVIDIA": 1200
50
+ }
51
+
52
+ # Assign default size for missing companies
53
+ default_size = 5
54
+
55
+ # Define a color map for the investors
56
+ investor_colors = {
57
+ "Accel": "#1f77b4",
58
+ "Andreessen Horowitz": "#ff7f0e",
59
+ "Google Ventures": "#2ca02c",
60
+ "Greylock": "#d62728",
61
+ "Lightspeed Venture Partners": "#9467bd",
62
+ "Benchmark": "#8c564b",
63
+ "Norwest Venture Partners": "#e377c2",
64
+ "Emergence Capital Partners": "#7f7f7f",
65
+ "Trinity Ventures": "#bcbd22",
66
+ "Citi Ventures": "#17becf",
67
+ "Sequoia": "#1b9e77",
68
+ "Y Combinator": "#d95f02"
69
+ }
70
+
71
+ def generate_graph(selected_investors):
72
+ if not selected_investors:
73
+ selected_investors = list(investors.keys())
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  G = nx.Graph()
76
+
77
+ # Add edges and nodes based on selected investors
78
+ for investor in selected_investors:
79
+ companies = investors[investor]
80
  for company in companies:
81
+ G.add_edge(investor, company, color=investor_colors[investor])
82
+
83
+ # Get edge colors
84
+ edge_colors = [G[u][v]['color'] for u, v in G.edges]
85
+
86
+ # Set node colors and sizes
87
+ node_colors = []
88
+ node_sizes = []
89
+ for node in G.nodes:
90
+ if node in investor_colors:
91
+ node_colors.append(investor_colors[node])
92
+ node_sizes.append(2000) # Fixed size for investors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  else:
94
+ node_colors.append("#F0E68C") # Khaki for companies
95
+ node_sizes.append(market_cap.get(node, default_size) * 100) # Scale up sizes
96
+
97
+ # Create plot
98
+ plt.figure(figsize=(18, 18))
99
+ pos = nx.spring_layout(G, k=0.2, seed=42) # Fixed seed for consistency
100
+ nx.draw(
101
+ G, pos,
102
+ with_labels=True,
103
+ node_size=node_sizes,
104
+ node_color=node_colors,
105
+ font_size=10,
106
+ font_weight="bold",
107
+ edge_color=edge_colors,
108
+ width=2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  )
110
+ plt.title("Venture Funded Companies as a Densely Connected Subgraph", fontsize=20)
111
+ plt.axis('off')
112
+
113
+ # Save plot to a BytesIO object
114
+ buf = BytesIO()
115
+ plt.savefig(buf, format="png", bbox_inches="tight")
116
+ plt.close()
117
+ buf.seek(0)
118
 
119
+ # Convert BytesIO to PIL image
120
+ image = Image.open(buf)
121
+ return image
 
 
 
 
122
 
123
+ # Define Gradio interface
124
  def main():
125
+ # Create a sorted list of investors for better UX
126
+ investor_list = sorted(investors.keys())
127
+
128
+ iface = gr.Interface(
129
+ fn=generate_graph,
130
+ inputs=gr.CheckboxGroup(
131
+ choices=investor_list,
132
+ label="Select Investors",
133
+ value=investor_list # Default to all selected
134
+ ),
135
+ outputs=gr.Image(type="pil", label="Venture Network Graph"),
136
+ title="Venture Networks Visualization",
137
+ description="Select investors to visualize their investments in various companies. The graph shows connections between investors and the companies they've invested in. Node sizes represent market capitalization.",
138
+ flagging_mode="never"
139
+ )
140
+
141
+ iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  if __name__ == "__main__":
144
  main()