Spaces:
Sleeping
Sleeping
File size: 6,407 Bytes
a165958 45a7450 9349152 a165958 45a7450 a165958 9327810 e63418f 9327810 e63418f 45a7450 9327810 e63418f 9327810 e63418f 45a7450 a165958 45a7450 a165958 e63418f a165958 e63418f 09b019d 45a7450 a165958 45a7450 e63418f a165958 45a7450 a165958 45a7450 a165958 45a7450 a165958 45a7450 99eb020 45a7450 09b019d 45a7450 09b019d 45a7450 09b019d a165958 45a7450 a165958 45a7450 9349152 45a7450 9327810 a165958 45a7450 e63418f 45a7450 9327810 45a7450 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import gradio as gr
# Load and preprocess the dataset.
# The first row of the CSV body holds header labels, so it is dropped and the
# columns are given fixed, code-friendly names instead.
file_path = "cbinsights_data.csv"  # Replace with your file path
data = pd.read_csv(file_path)
# .copy() makes `data` an independent frame so the column assignments below do
# not trigger pandas' chained-assignment (SettingWithCopy) warnings.
data = data.iloc[1:].copy()
data.columns = [
    "Company",
    "Valuation_Billions",
    "Date_Joined",
    "Country",
    "City",
    "Industry",
    "Select_Investors",
]


def _strip_text(value):
    """Return ``value`` stripped of surrounding whitespace if it is a string, else unchanged."""
    return value.strip() if isinstance(value, str) else value


# Trim whitespace in every cell first, so the valuation parser sees clean text.
# Column-wise Series.map replaces DataFrame.applymap, deprecated since pandas 2.1.
for _column in data.columns:
    data[_column] = data[_column].map(_strip_text)

# Parse valuations such as "$1.5" into floats (billions). The literal "$" and
# thousands separators are removed; regex=False keeps "$" from being treated as
# a regex anchor on any pandas version. Unparseable entries become NaN.
data["Valuation_Billions"] = pd.to_numeric(
    data["Valuation_Billions"]
    .str.replace("$", "", regex=False)
    .str.replace(",", "", regex=False),
    errors="coerce",
)
# Parse the "Select_Investors" column into an investor -> [companies] mapping.
# Investors are comma-separated. Blank names (e.g. from trailing commas) are
# skipped. Each company is recorded at most once per investor, so per-investor
# valuation totals computed from this mapping do not double count it.
investor_company_mapping = {}
for company, investors in zip(data["Company"], data["Select_Investors"]):
    if pd.isnull(investors):
        continue
    for investor in investors.split(","):
        investor = investor.strip()
        if not investor:
            continue
        portfolio = investor_company_mapping.setdefault(investor, [])
        if company not in portfolio:
            portfolio.append(company)
# Gradio app functions
def filter_investors_by_country_and_industry(selected_country, selected_industry, min_total_valuation=20):
    """List investors whose matching companies reach a total valuation threshold.

    Args:
        selected_country: Country to keep, or "All" to skip the country filter.
        selected_industry: Industry to keep, or "All" to skip the industry filter.
        min_total_valuation: Minimum summed "Valuation_Billions" an investor's
            companies (within the filtered rows) must reach. Defaults to 20.

    Returns:
        A tuple ``(investors, filtered_data)``. ``investors`` is the list of
        qualifying investor names in ``investor_company_mapping`` order.
        ``filtered_data`` is the filtered slice of the module-level ``data``.
    """
    filtered_data = data
    # Apply filters
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    # Build a company -> valuation lookup once, instead of scanning the frame
    # for every (investor, company) pair. For duplicate company names the first
    # row wins. Missing or unparseable valuations count as 0, so one NaN cannot
    # poison an investor's total (NaN >= threshold is always False).
    first_rows = filtered_data.drop_duplicates(subset="Company", keep="first")
    valuation_by_company = dict(
        zip(first_rows["Company"], first_rows["Valuation_Billions"].fillna(0))
    )

    qualifying_investors = []
    for investor, companies in investor_company_mapping.items():
        total_valuation = sum(valuation_by_company.get(company, 0) for company in companies)
        if total_valuation >= min_total_valuation:
            qualifying_investors.append(investor)
    return qualifying_investors, filtered_data
def generate_graph(selected_investors, filtered_data):
    """Render the investor/company network for the selected investors as an image.

    Args:
        selected_investors: Investor names chosen in the UI.
        filtered_data: DataFrame slice of the dataset; only companies present
            in it are drawn. May be ``None`` before any filter has run.

    Returns:
        A PIL image of the graph, or ``None`` when there is nothing to draw.
        That is the case when no investors are selected, ``filtered_data`` is
        missing, or no selected investor has a company left after filtering.
    """
    if not selected_investors or filtered_data is None:
        return None

    selected = set(selected_investors)
    # One lookup for membership and node sizing; the first row wins for
    # duplicate company names.
    first_rows = filtered_data.drop_duplicates(subset="Company", keep="first")
    valuation_by_company = dict(zip(first_rows["Company"], first_rows["Valuation_Billions"]))

    # Keep only selected investors that still have companies after filtering.
    filtered_mapping = {}
    for investor, companies in investor_company_mapping.items():
        if investor not in selected:
            continue
        kept = [c for c in companies if c in valuation_by_company]
        if kept:
            filtered_mapping[investor] = kept
    if not filtered_mapping:
        return None

    # Bipartite-style graph: one edge per investor -> company link.
    G = nx.Graph()
    for investor, companies in filtered_mapping.items():
        G.add_edges_from((investor, company) for company in companies)

    node_sizes = []
    node_colors = []
    for node in G.nodes:
        if node in filtered_mapping:
            node_sizes.append(2000)  # Fixed size for investors
            node_colors.append("#FF5733")  # Distinct color for investors
        else:
            valuation = valuation_by_company.get(node)
            # Size companies by valuation. NaN/missing falls back to the
            # default size instead of handing matplotlib a NaN marker size.
            node_sizes.append(100 if pd.isna(valuation) else valuation * 50)
            node_colors.append("#33FF57")  # Distinct color for companies

    # A full-figure axes reproduces the layout nx.draw creates on a fresh figure.
    fig = plt.figure(figsize=(18, 18))
    ax = fig.add_axes((0, 0, 1, 1))
    try:
        pos = nx.spring_layout(G, k=0.2, seed=42)  # Fixed seed for consistent layout
        nx.draw(
            G, pos,
            ax=ax,
            with_labels=True,
            node_size=node_sizes,
            node_color=node_colors,
            alpha=0.8,  # Slight transparency for Tufte-inspired visuals
            font_size=10,
            font_weight="bold",
            edge_color="#B0BEC5",  # Neutral, muted edge color
            width=0.8,  # Thin edges for minimal visual clutter
        )
        # Legend for node size: size = valuation * 50, so 50 -> $1B and 5000 -> $100B.
        for size, label in zip([50, 5000], ["$1B", "$100B"]):
            ax.scatter([], [], s=size, color="#33FF57", label=f"{label} valuation")
        ax.legend(scatterpoints=1, frameon=False, labelspacing=1.5, loc="lower left", fontsize=12)
        ax.set_title("Venture Funded Companies Visualization", fontsize=20)
        ax.axis("off")

        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if layout or drawing raised.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def app(selected_country, selected_industry):
    """Refresh the investor checklist after a country/industry filter change.

    Returns a ``gr.update`` that makes the investor CheckboxGroup visible,
    lists every qualifying investor, and pre-selects them all. The filtered
    DataFrame is returned alongside it.
    """
    investors, subset = filter_investors_by_country_and_industry(
        selected_country, selected_industry
    )
    checklist = gr.update(visible=True, choices=investors, value=investors)
    return checklist, subset
# Gradio Interface
def main():
    """Build the Gradio Blocks UI and launch it.

    Changing either dropdown (country, industry) re-runs ``app`` to refresh the
    investor checklist and the stored filtered data. Changing the checklist
    re-renders the network image with ``generate_graph``.
    """

    def _options(column):
        # "All" first, then the sorted distinct non-null values of ``column``.
        return ["All", *sorted(data[column].dropna().unique())]

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(
                choices=_options("Country"), label="Filter by Country", value="All"
            )
            industry_filter = gr.Dropdown(
                choices=_options("Industry"), label="Filter by Industry", value="All"
            )
        filtered_investor_list = gr.CheckboxGroup(
            choices=[], label="Select Investors", visible=False
        )
        graph_output = gr.Image(type="pil", label="Venture Network Graph")
        filtered_data_holder = gr.State()

        # Both dropdowns trigger the same refresh with the same inputs/outputs.
        for dropdown in (country_filter, industry_filter):
            dropdown.change(
                app,
                inputs=[country_filter, industry_filter],
                outputs=[filtered_investor_list, filtered_data_holder],
            )
        filtered_investor_list.change(
            generate_graph,
            inputs=[filtered_investor_list, filtered_data_holder],
            outputs=graph_output,
        )
    demo.launch()


if __name__ == "__main__":
    main()
|