# Spaces: Sleeping
import pandas as pd | |
import networkx as nx | |
import matplotlib.pyplot as plt | |
from io import BytesIO | |
from PIL import Image | |
import gradio as gr | |
# Load and preprocess the dataset.
# NOTE: the CSV is read when this module is imported, so the file must exist.
file_path = "cbinsights_data.csv"  # Replace with your file path
data = pd.read_csv(file_path)

# The first row of the parsed frame is a second header row: drop it and assign
# fixed column names. `.copy()` detaches the slice so the in-place column
# assignments below operate on an independent frame.
data = data.iloc[1:].copy()
data.columns = [
    "Company",
    "Valuation_Billions",
    "Date_Joined",
    "Country",
    "City",
    "Industry",
    "Select_Investors",
]

# Strip surrounding whitespace from every string cell *before* parsing, so the
# valuation conversion and later equality filters see clean values.
# (Per-column Series.map is used instead of the deprecated DataFrame.applymap.)
for _column in data.columns:
    data[_column] = data[_column].map(
        lambda cell: cell.strip() if isinstance(cell, str) else cell
    )

# Convert valuations such as "$12" to numbers. `regex=False` makes "$" a
# literal character; as a regex it would be the end-of-string anchor.
# NOTE(review): only the text before the first "." is kept, so "$1.5" becomes
# 1 -- confirm that dropping the fractional part is intended.
data["Valuation_Billions"] = (
    data["Valuation_Billions"]
    .str.replace("$", "", regex=False)
    .str.split(".")
    .str[0]
)
data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors="coerce")

# Parse the comma-separated "Select_Investors" column into a mapping of
# investor name -> list of companies that investor backs.
investor_company_mapping = {}
for _company, _investors in zip(data["Company"], data["Select_Investors"]):
    if not isinstance(_investors, str):
        # Missing investor cell: nothing to record for this company.
        continue
    for _investor in _investors.split(","):
        _investor = _investor.strip()
        if not _investor:
            # Stray or trailing commas would otherwise create a "" investor.
            continue
        investor_company_mapping.setdefault(_investor, []).append(_company)
# Gradio app function
def generate_graph(selected_investors, selected_country, selected_industry):
    """Render the investor-company network as a PIL image.

    Args:
        selected_investors: Iterable of investor names to include. ``None``
            applies no investor filter; an empty selection yields an empty
            graph.
        selected_country: Country to keep, or "All" for no country filter.
        selected_industry: Industry to keep, or "All" for no industry filter.

    Returns:
        A PIL image (PNG-encoded in memory) of the drawn graph. Investors are
        drawn at a fixed size; company node sizes scale with
        ``Valuation_Billions``.
    """
    filtered_data = data

    # Apply country filter
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]

    # Apply industry filter
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    # One pass over the filtered rows gives O(1) membership tests and valuation
    # lookups. The first row wins when a company name repeats; rows without a
    # company name are ignored.
    valuation_by_company = {}
    for name, valuation in zip(
        filtered_data["Company"], filtered_data["Valuation_Billions"]
    ):
        if pd.notna(name):
            valuation_by_company.setdefault(name, valuation)

    # Honour the investor selection coming from the UI.
    chosen_investors = None if selected_investors is None else set(selected_investors)

    # Filter the investor-to-company mapping
    filtered_mapping = {}
    for investor, companies in investor_company_mapping.items():
        if chosen_investors is not None and investor not in chosen_investors:
            continue
        remaining = [c for c in companies if c in valuation_by_company]
        if remaining:
            filtered_mapping[investor] = remaining

    # Use the filtered mapping to build the graph. Nodes are keyed by name, so
    # an investor and a company sharing a name collapse into one node.
    G = nx.Graph()
    for investor, companies in filtered_mapping.items():
        for company in companies:
            G.add_edge(investor, company)

    # Node sizes and colors: fixed size/color for investors, valuation-scaled
    # size for companies.
    node_sizes = []
    node_colors = []
    for node in G.nodes:
        if node in filtered_mapping:
            node_sizes.append(2000)
            node_colors.append("#FF5733")  # Distinct color for investors
        else:
            valuation = valuation_by_company.get(node)
            # Unknown or unparseable (NaN) valuations fall back to the default
            # size instead of producing a NaN node size.
            node_sizes.append(valuation * 100 if pd.notna(valuation) else 100)
            node_colors.append("#33FF57")  # Distinct color for companies

    # Create the graph plot; the figure is always closed, even if drawing or
    # saving raises, so repeated calls do not accumulate open figures.
    fig = plt.figure(figsize=(18, 18))
    try:
        if G.number_of_nodes() > 0:
            pos = nx.spring_layout(G, k=0.2, seed=42)  # Fixed seed for consistent layout
            nx.draw(
                G, pos,
                with_labels=True,
                node_size=node_sizes,
                node_color=node_colors,
                font_size=10,
                font_weight="bold",
                edge_color="gray",
                width=1.5,
            )
        else:
            # Nothing matched the filters: show a short notice instead.
            plt.text(
                0.5, 0.5, "No matching investments",
                ha="center", va="center", fontsize=16,
            )
        plt.title("Venture Funded Companies Visualization", fontsize=20)
        plt.axis('off')

        # Save plot to BytesIO object
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)

    # Convert BytesIO to PIL image
    buf.seek(0)
    return Image.open(buf)
# Gradio Interface
def main():
    """Assemble the Gradio UI around ``generate_graph`` and launch it.

    The investor checkbox group starts with every investor ticked, and both
    dropdowns start at "All" (no filtering).
    """
    all_investors = sorted(investor_company_mapping)
    countries = ["All", *sorted(data["Country"].dropna().unique())]
    industries = ["All", *sorted(data["Industry"].dropna().unique())]

    investor_picker = gr.CheckboxGroup(
        choices=all_investors,
        label="Select Investors",
        value=all_investors,  # Default to all selected
    )
    country_picker = gr.Dropdown(
        choices=countries,
        label="Filter by Country",
        value="All",  # Default to no filter
    )
    industry_picker = gr.Dropdown(
        choices=industries,
        label="Filter by Industry",
        value="All",  # Default to no filter
    )
    graph_output = gr.Image(type="pil", label="Venture Network Graph")

    blurb = (
        "Select investors and apply optional filters by country and industry "
        "to visualize their investments. The graph shows connections between "
        "investors and the companies they've invested in. Node sizes represent company valuations."
    )

    app = gr.Interface(
        fn=generate_graph,
        inputs=[investor_picker, country_picker, industry_picker],
        outputs=graph_output,
        title="Venture Networks Visualization",
        description=blurb,
        flagging_mode="never",
    )
    app.launch()
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()