Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,18 +4,18 @@ import plotly.graph_objects as go
|
|
4 |
import gradio as gr
|
5 |
import logging
|
6 |
|
7 |
-
#
|
8 |
logging.basicConfig(level=logging.INFO)
|
9 |
logger = logging.getLogger(__name__)
|
10 |
|
11 |
-
# Load and
|
12 |
-
|
13 |
|
14 |
try:
|
15 |
-
data = pd.read_csv(
|
16 |
logger.info("CSV file loaded successfully.")
|
17 |
except FileNotFoundError:
|
18 |
-
logger.error(f"File not found: {
|
19 |
raise
|
20 |
except Exception as e:
|
21 |
logger.error(f"Error loading CSV file: {e}")
|
@@ -35,8 +35,8 @@ valuation_column = valuation_columns[0]
|
|
35 |
logger.info(f"Using valuation column: {valuation_column}")
|
36 |
|
37 |
# Clean and prepare data
|
38 |
-
data["
|
39 |
-
data["
|
40 |
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
|
41 |
data.rename(columns={
|
42 |
"company": "Company",
|
@@ -49,8 +49,11 @@ data.rename(columns={
|
|
49 |
|
50 |
logger.info("Data cleaned and columns renamed.")
|
51 |
|
52 |
-
# Build
|
53 |
def build_investor_company_mapping(df):
|
|
|
|
|
|
|
54 |
mapping = {}
|
55 |
for _, row in df.iterrows():
|
56 |
company = row["Company"]
|
@@ -65,8 +68,11 @@ def build_investor_company_mapping(df):
|
|
65 |
investor_company_mapping = build_investor_company_mapping(data)
|
66 |
logger.info("Investor to company mapping created.")
|
67 |
|
68 |
-
# Filter
|
69 |
def filter_investors(selected_country, selected_industry, selected_investors):
|
|
|
|
|
|
|
70 |
filtered_data = data.copy()
|
71 |
if selected_country != "All":
|
72 |
filtered_data = filtered_data[filtered_data["Country"] == selected_country]
|
@@ -75,144 +81,255 @@ def filter_investors(selected_country, selected_industry, selected_investors):
|
|
75 |
if selected_investors:
|
76 |
pattern = '|'.join(selected_investors)
|
77 |
filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
|
78 |
-
|
79 |
investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
|
80 |
filtered_investors = list(investor_company_mapping_filtered.keys())
|
81 |
return filtered_investors, filtered_data
|
82 |
|
83 |
-
# Generate
|
84 |
def generate_graph(investors, filtered_data):
|
|
|
|
|
|
|
85 |
if not investors:
|
86 |
logger.warning("No investors selected.")
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
#
|
90 |
-
unique_investors = investors
|
91 |
-
num_colors = len(unique_investors)
|
92 |
color_palette = [
|
93 |
-
"#
|
94 |
-
"#
|
95 |
-
"#
|
96 |
-
"#
|
97 |
-
"#
|
98 |
-
"#
|
99 |
-
"#
|
100 |
-
"#
|
101 |
]
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
108 |
G = nx.Graph()
|
109 |
for investor in investors:
|
110 |
companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
|
111 |
for company in companies:
|
112 |
-
G.add_node(company)
|
113 |
-
G.add_node(investor)
|
114 |
G.add_edge(investor, company)
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
117 |
edge_x = []
|
118 |
edge_y = []
|
119 |
-
|
120 |
for edge in G.edges():
|
121 |
x0, y0 = pos[edge[0]]
|
122 |
x1, y1 = pos[edge[1]]
|
123 |
-
edge_x
|
124 |
-
edge_y
|
125 |
-
|
126 |
edge_trace = go.Scatter(
|
127 |
x=edge_x,
|
128 |
y=edge_y,
|
129 |
-
line=dict(width=
|
130 |
hoverinfo='none',
|
131 |
mode='lines'
|
132 |
)
|
133 |
-
|
|
|
134 |
node_x = []
|
135 |
node_y = []
|
136 |
node_text = []
|
137 |
node_color = []
|
138 |
node_size = []
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
142 |
node_x.append(x)
|
143 |
node_y.append(y)
|
144 |
-
|
145 |
-
|
146 |
-
node_text.append(node) #
|
147 |
-
node_color.append(investor_color_map[node])
|
148 |
node_size.append(20) # Fixed size for investors
|
149 |
else:
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
size = valuation[0] * 5 # Scale size as needed
|
154 |
-
if size < 5:
|
155 |
-
size = 5 # Minimum size
|
156 |
-
else:
|
157 |
-
size = 10 # Default size
|
158 |
node_size.append(size)
|
159 |
-
node_text.append(""
|
160 |
-
node_color.append("#b2df8a") # Light green
|
161 |
-
|
162 |
node_trace = go.Scatter(
|
163 |
x=node_x,
|
164 |
y=node_y,
|
|
|
165 |
text=node_text,
|
166 |
-
|
167 |
hoverinfo='text',
|
168 |
marker=dict(
|
169 |
showscale=False,
|
170 |
-
size=node_size,
|
171 |
color=node_color,
|
|
|
|
|
172 |
)
|
173 |
)
|
174 |
-
|
|
|
175 |
fig = go.Figure(data=[edge_trace, node_trace])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
fig.update_layout(
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
hovermode='closest',
|
182 |
width=1200,
|
183 |
-
height=800
|
|
|
|
|
184 |
)
|
|
|
185 |
return fig
|
186 |
|
187 |
-
# Gradio
|
188 |
def app(selected_country, selected_industry, selected_investors):
|
|
|
|
|
|
|
189 |
investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
|
190 |
graph = generate_graph(investors, filtered_data)
|
191 |
-
return ', '.join(investors), graph
|
192 |
|
193 |
-
# Main function
|
194 |
def main():
|
|
|
|
|
|
|
|
|
195 |
country_list = ["All"] + sorted(data["Country"].dropna().unique())
|
196 |
industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
|
197 |
investor_list = sorted(investor_company_mapping.keys())
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
|
|
208 |
inputs = [country_filter, industry_filter, investor_filter]
|
209 |
outputs = [investor_output, graph_output]
|
210 |
|
211 |
-
|
212 |
-
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
-
|
|
|
216 |
|
217 |
if __name__ == "__main__":
|
218 |
main()
|
|
|
4 |
import gradio as gr
|
5 |
import logging
|
6 |
|
7 |
+
# -------------------- Setup Logging --------------------
|
8 |
logging.basicConfig(level=logging.INFO)
|
9 |
logger = logging.getLogger(__name__)
|
10 |
|
11 |
+
# -------------------- Load and Preprocess Data --------------------
|
12 |
+
FILE_PATH = "cbinsights_data.csv" # Replace with your actual file path
|
13 |
|
14 |
try:
|
15 |
+
data = pd.read_csv(FILE_PATH, skiprows=1)
|
16 |
logger.info("CSV file loaded successfully.")
|
17 |
except FileNotFoundError:
|
18 |
+
logger.error(f"File not found: {FILE_PATH}")
|
19 |
raise
|
20 |
except Exception as e:
|
21 |
logger.error(f"Error loading CSV file: {e}")
|
|
|
35 |
logger.info(f"Using valuation column: {valuation_column}")
|
36 |
|
37 |
# Clean and prepare data
|
38 |
+
data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
|
39 |
+
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
|
40 |
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
|
41 |
data.rename(columns={
|
42 |
"company": "Company",
|
|
|
49 |
|
50 |
logger.info("Data cleaned and columns renamed.")
|
51 |
|
52 |
+
# -------------------- Build Investor-Company Mapping --------------------
|
53 |
def build_investor_company_mapping(df):
|
54 |
+
"""
|
55 |
+
Builds a mapping from investors to the companies they've invested in.
|
56 |
+
"""
|
57 |
mapping = {}
|
58 |
for _, row in df.iterrows():
|
59 |
company = row["Company"]
|
|
|
68 |
investor_company_mapping = build_investor_company_mapping(data)
|
69 |
logger.info("Investor to company mapping created.")
|
70 |
|
71 |
+
# -------------------- Filter Investors --------------------
|
72 |
def filter_investors(selected_country, selected_industry, selected_investors):
|
73 |
+
"""
|
74 |
+
Filters the dataset based on selected country, industry, and investors.
|
75 |
+
"""
|
76 |
filtered_data = data.copy()
|
77 |
if selected_country != "All":
|
78 |
filtered_data = filtered_data[filtered_data["Country"] == selected_country]
|
|
|
81 |
if selected_investors:
|
82 |
pattern = '|'.join(selected_investors)
|
83 |
filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
|
84 |
+
|
85 |
investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
|
86 |
filtered_investors = list(investor_company_mapping_filtered.keys())
|
87 |
return filtered_investors, filtered_data
|
88 |
|
89 |
+
# -------------------- Generate Network Graph --------------------
|
90 |
def generate_graph(investors, filtered_data):
|
91 |
+
"""
|
92 |
+
Generates an interactive network graph using Plotly.
|
93 |
+
"""
|
94 |
if not investors:
|
95 |
logger.warning("No investors selected.")
|
96 |
+
# Create an empty figure with a message
|
97 |
+
fig = go.Figure()
|
98 |
+
fig.update_layout(
|
99 |
+
title="No data available for the selected filters.",
|
100 |
+
xaxis=dict(visible=False),
|
101 |
+
yaxis=dict(visible=False),
|
102 |
+
annotations=[dict(
|
103 |
+
text="Please adjust your filters to display the network graph.",
|
104 |
+
showarrow=False,
|
105 |
+
xref="paper", yref="paper",
|
106 |
+
x=0.5, y=0.5,
|
107 |
+
font=dict(size=20)
|
108 |
+
)]
|
109 |
+
)
|
110 |
+
return fig
|
111 |
|
112 |
+
# Define a color-blind friendly palette
|
|
|
|
|
113 |
color_palette = [
|
114 |
+
"#377eb8", # Blue
|
115 |
+
"#e41a1c", # Red
|
116 |
+
"#4daf4a", # Green
|
117 |
+
"#984ea3", # Purple
|
118 |
+
"#ff7f00", # Orange
|
119 |
+
"#ffff33", # Yellow
|
120 |
+
"#a65628", # Brown
|
121 |
+
"#f781bf", # Pink
|
122 |
]
|
123 |
+
|
124 |
+
# Assign colors to investors
|
125 |
+
unique_investors = investors
|
126 |
+
num_colors = len(unique_investors)
|
127 |
+
color_palette_extended = color_palette * (num_colors // len(color_palette) + 1)
|
128 |
+
investor_color_map = {investor: color_palette_extended[i] for i, investor in enumerate(unique_investors)}
|
129 |
+
|
130 |
+
# Create the graph
|
131 |
G = nx.Graph()
|
132 |
for investor in investors:
|
133 |
companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
|
134 |
for company in companies:
|
135 |
+
G.add_node(company, type='company', valuation=filtered_data.loc[filtered_data["Company"] == company, "Valuation_Billions"].values[0])
|
136 |
+
G.add_node(investor, type='investor')
|
137 |
G.add_edge(investor, company)
|
138 |
+
|
139 |
+
# Position nodes using spring layout
|
140 |
+
pos = nx.spring_layout(G, seed=42, k=0.5)
|
141 |
+
|
142 |
+
# Prepare edge traces
|
143 |
edge_x = []
|
144 |
edge_y = []
|
|
|
145 |
for edge in G.edges():
|
146 |
x0, y0 = pos[edge[0]]
|
147 |
x1, y1 = pos[edge[1]]
|
148 |
+
edge_x += [x0, x1, None]
|
149 |
+
edge_y += [y0, y1, None]
|
150 |
+
|
151 |
edge_trace = go.Scatter(
|
152 |
x=edge_x,
|
153 |
y=edge_y,
|
154 |
+
line=dict(width=0.5, color='#888'),
|
155 |
hoverinfo='none',
|
156 |
mode='lines'
|
157 |
)
|
158 |
+
|
159 |
+
# Prepare node traces
|
160 |
node_x = []
|
161 |
node_y = []
|
162 |
node_text = []
|
163 |
node_color = []
|
164 |
node_size = []
|
165 |
+
node_type = []
|
166 |
+
|
167 |
+
for node in G.nodes(data=True):
|
168 |
+
x, y = pos[node[0]]
|
169 |
node_x.append(x)
|
170 |
node_y.append(y)
|
171 |
+
node_type.append(node[1]['type'])
|
172 |
+
if node[1]['type'] == 'investor':
|
173 |
+
node_text.append(node[0]) # Investor labels
|
174 |
+
node_color.append(investor_color_map[node[0]])
|
175 |
node_size.append(20) # Fixed size for investors
|
176 |
else:
|
177 |
+
valuation = node[1]['valuation']
|
178 |
+
size = (valuation * 5) if pd.notnull(valuation) else 10 # Scale size
|
179 |
+
size = max(size, 5) # Minimum size
|
|
|
|
|
|
|
|
|
|
|
180 |
node_size.append(size)
|
181 |
+
node_text.append(f"{node[0]}<br>Valuation: ${valuation}B" if pd.notnull(valuation) else f"{node[0]}<br>Valuation: N/A")
|
182 |
+
node_color.append("#b2df8a") # Light green for companies
|
183 |
+
|
184 |
node_trace = go.Scatter(
|
185 |
x=node_x,
|
186 |
y=node_y,
|
187 |
+
mode='markers+text',
|
188 |
text=node_text,
|
189 |
+
textposition="top center",
|
190 |
hoverinfo='text',
|
191 |
marker=dict(
|
192 |
showscale=False,
|
|
|
193 |
color=node_color,
|
194 |
+
size=node_size,
|
195 |
+
line=dict(width=1, color='white')
|
196 |
)
|
197 |
)
|
198 |
+
|
199 |
+
# Create the figure
|
200 |
fig = go.Figure(data=[edge_trace, node_trace])
|
201 |
+
|
202 |
+
# Add legend manually
|
203 |
+
investor_colors = list(investor_color_map.values())[:8] # Limit to first 8 for legend
|
204 |
+
investor_names = list(investor_color_map.keys())[:8]
|
205 |
+
|
206 |
+
for i, investor in enumerate(investor_names):
|
207 |
+
fig.add_trace(go.Scatter(
|
208 |
+
x=[None],
|
209 |
+
y=[None],
|
210 |
+
mode='markers',
|
211 |
+
marker=dict(
|
212 |
+
size=10,
|
213 |
+
color=investor_color_map[investor]
|
214 |
+
),
|
215 |
+
legendgroup='Investors',
|
216 |
+
showlegend=True,
|
217 |
+
name=investor
|
218 |
+
))
|
219 |
+
|
220 |
+
# Update layout for better aesthetics
|
221 |
fig.update_layout(
|
222 |
+
title={
|
223 |
+
'text': "Venture Networks",
|
224 |
+
'y':0.95,
|
225 |
+
'x':0.5,
|
226 |
+
'xanchor': 'center',
|
227 |
+
'yanchor': 'top'
|
228 |
+
},
|
229 |
+
titlefont_size=24,
|
230 |
+
showlegend=True,
|
231 |
+
legend=dict(
|
232 |
+
title="Top Investors",
|
233 |
+
itemsizing='constant',
|
234 |
+
itemclick='toggleothers',
|
235 |
+
itemdoubleclick='toggle',
|
236 |
+
font=dict(size=10)
|
237 |
+
),
|
238 |
+
margin=dict(l=40, r=40, t=80, b=40),
|
239 |
hovermode='closest',
|
240 |
width=1200,
|
241 |
+
height=800,
|
242 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
243 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
244 |
)
|
245 |
+
|
246 |
return fig
|
247 |
|
248 |
+
# -------------------- Gradio Application --------------------
|
249 |
def app(selected_country, selected_industry, selected_investors):
|
250 |
+
"""
|
251 |
+
Main application function that filters data and generates the network graph.
|
252 |
+
"""
|
253 |
investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
|
254 |
graph = generate_graph(investors, filtered_data)
|
255 |
+
return ', '.join(investors) if investors else "No investors found.", graph
|
256 |
|
|
|
257 |
def main():
|
258 |
+
"""
|
259 |
+
Initializes and launches the Gradio interface.
|
260 |
+
"""
|
261 |
+
# Prepare dropdown options
|
262 |
country_list = ["All"] + sorted(data["Country"].dropna().unique())
|
263 |
industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
|
264 |
investor_list = sorted(investor_company_mapping.keys())
|
265 |
|
266 |
+
# Define Gradio Blocks
|
267 |
+
with gr.Blocks(css="""
|
268 |
+
.gradio-container {
|
269 |
+
background-color: #f9f9f9;
|
270 |
+
padding: 20px;
|
271 |
+
}
|
272 |
+
.gradio-row {
|
273 |
+
justify-content: center;
|
274 |
+
}
|
275 |
+
""") as demo:
|
276 |
+
gr.Markdown("""
|
277 |
+
# Venture Networks Visualization
|
278 |
+
|
279 |
+
Explore the relationships between investors and companies across different countries and industries. Use the filters below to customize the network graph.
|
280 |
+
|
281 |
+
**Instructions:**
|
282 |
+
- Select a country and/or industry to filter the data.
|
283 |
+
- Choose one or more investors to focus on specific investment activities.
|
284 |
+
- Hover over company nodes to view their valuations.
|
285 |
+
""")
|
286 |
|
287 |
+
with gr.Row():
|
288 |
+
with gr.Column(scale=1):
|
289 |
+
country_filter = gr.Dropdown(
|
290 |
+
choices=country_list,
|
291 |
+
label="Country",
|
292 |
+
value="All",
|
293 |
+
info="Select a country to filter the data."
|
294 |
+
)
|
295 |
+
industry_filter = gr.Dropdown(
|
296 |
+
choices=industry_list,
|
297 |
+
label="Industry",
|
298 |
+
value="All",
|
299 |
+
info="Select an industry to filter the data."
|
300 |
+
)
|
301 |
+
investor_filter = gr.Dropdown(
|
302 |
+
choices=investor_list,
|
303 |
+
label="Investor",
|
304 |
+
value=[],
|
305 |
+
multiselect=True,
|
306 |
+
info="Select one or more investors to focus on their investments."
|
307 |
+
)
|
308 |
+
reset_button = gr.Button("Reset Filters", variant="secondary")
|
309 |
+
with gr.Column(scale=3):
|
310 |
+
graph_output = gr.Plot(label="Network Graph")
|
311 |
+
|
312 |
+
with gr.Row():
|
313 |
+
investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
|
314 |
|
315 |
+
# Define Inputs and Outputs
|
316 |
inputs = [country_filter, industry_filter, investor_filter]
|
317 |
outputs = [investor_output, graph_output]
|
318 |
|
319 |
+
# Define Event Handlers
|
320 |
+
country_filter.change(fn=app, inputs=inputs, outputs=outputs)
|
321 |
+
industry_filter.change(fn=app, inputs=inputs, outputs=outputs)
|
322 |
+
investor_filter.change(fn=app, inputs=inputs, outputs=outputs)
|
323 |
+
reset_button.click(fn=lambda: ["", go.Figure()], inputs=None, outputs=outputs)
|
324 |
+
|
325 |
+
# Add Footer
|
326 |
+
gr.Markdown("""
|
327 |
+
---
|
328 |
+
© 2024 Venture Networks Visualization Tool
|
329 |
+
""")
|
330 |
|
331 |
+
# Launch the Gradio app
|
332 |
+
demo.launch()
|
333 |
|
334 |
if __name__ == "__main__":
|
335 |
main()
|