# networkx-saas / app.py
# LeonceNsh's picture
# Update app.py
# cc9514f verified
# raw
# history blame
# 6.56 kB
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import gradio as gr
# Load and preprocess the dataset
file_path = "cbinsights_data.csv"  # Replace with your file path
# skiprows=1 drops the first line of the file before the header row is read.
data = pd.read_csv(file_path, skiprows=1)

# Standardize column names: strip whitespace and convert to lowercase
data.columns = data.columns.str.strip().str.lower()
print("Standardized Column Names:", data.columns.tolist())

# Identify the valuation column dynamically; exactly one match is required.
valuation_columns = [col for col in data.columns if "valuation" in col.lower()]
if not valuation_columns:
    raise ValueError("No column containing 'Valuation' found in the dataset.")
if len(valuation_columns) > 1:
    raise ValueError("Multiple columns containing 'Valuation' found. Please specify.")
valuation_column = valuation_columns[0]

# Clean and prepare data: drop "$" and thousands separators, then convert to
# numbers (values that still fail to parse become NaN via errors="coerce").
# The pattern is a raw string so "\$" is not treated as an invalid escape.
data["valuation_billions"] = data[valuation_column].replace(
    {r"\$": "", ",": ""}, regex=True
)
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors="coerce")

# Trim surrounding whitespace in every string cell. A column-wise Series.map
# is used instead of DataFrame.applymap, which is deprecated in recent pandas.
data = data.apply(
    lambda column: column.map(lambda x: x.strip() if isinstance(x, str) else x)
)

# Rename columns for consistency (optional)
# NOTE(review): the keys must match the lowercased CSV headers exactly; a
# header spelled with a space (e.g. "select investors") would be left
# unrenamed -- confirm against the actual file.
data = data.rename(columns={
    "company": "Company",
    "valuation_billions": "Valuation_Billions",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors",
})
# Parse the "Select_Investors" column to map investors to companies
def build_investor_company_mapping(df):
    """Map each investor name to the list of companies it is listed on.

    Args:
        df: DataFrame with a "Company" column and a "Select_Investors"
            column holding comma-separated investor names.

    Returns:
        dict of investor name -> list of "Company" values, in row order.
        Rows with a null "Select_Investors" are ignored, and so are blank
        names produced by trailing or doubled commas.
    """
    mapping = {}
    # Iterating the two columns directly avoids building a Series per row.
    for company, investors in zip(df["Company"], df["Select_Investors"]):
        if pd.isnull(investors):
            continue
        for investor in investors.split(","):
            investor = investor.strip()
            # "A, B," would otherwise register an investor named "".
            if investor:
                mapping.setdefault(investor, []).append(company)
    return mapping
# Investor -> companies mapping for the full, unfiltered dataset, built at import time.
# NOTE(review): the functions visible in this file rebuild their own mapping
# from the filtered data and do not read this global -- confirm it is still needed.
investor_company_mapping = build_investor_company_mapping(data)
# Function to filter investors based on selected country and industry
def filter_investors_by_country_and_industry(
    selected_country, selected_industry, min_total_valuation=20
):
    """Filter the dataset and list the investors that pass a valuation cut-off.

    Args:
        selected_country: value to match in the "Country" column, or "All"
            to skip the country filter.
        selected_industry: value to match in the "Industry" column, or "All"
            to skip the industry filter.
        min_total_valuation: minimum summed "Valuation_Billions" over an
            investor's companies (within the filtered rows) for the investor
            to be kept. Defaults to 20, the previously hard-coded threshold.

    Returns:
        Tuple ``(investors, filtered_data)``: the qualifying investor names
        in first-appearance order, and the filtered DataFrame.
    """
    filtered_data = data.copy()
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)

    # Column lookups are loop-invariant, so fetch them once.
    company_names = filtered_data["Company"]
    valuations = filtered_data["Valuation_Billions"]

    # Calculate total valuation per investor and keep those at or above the cut-off
    qualifying_investors = []
    for investor, companies in investor_company_mapping_filtered.items():
        total_valuation = valuations[company_names.isin(companies)].sum()
        if total_valuation >= min_total_valuation:
            qualifying_investors.append(investor)
    return qualifying_investors, filtered_data
# Function to generate the graph
def generate_graph(selected_investors, filtered_data):
    """Render the investor/company network for the selected investors.

    Args:
        selected_investors: iterable of investor names to draw.
        filtered_data: DataFrame with "Company", "Select_Investors" and
            "Valuation_Billions" columns, or None when no filter has run yet.

    Returns:
        A PIL image of the rendered graph, or None when there is nothing to
        draw (no selection, no data, or none of the selected investors
        appear in ``filtered_data``).
    """
    # The Gradio State holding the data starts out empty, so guard against it.
    if not selected_investors or filtered_data is None:
        return None

    from matplotlib.lines import Line2D

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    # A selection can be stale relative to the current filter; skip unknown
    # investors instead of failing with KeyError.
    filtered_mapping = {
        inv: investor_company_mapping_filtered[inv]
        for inv in selected_investors
        if inv in investor_company_mapping_filtered
    }
    if not filtered_mapping:
        return None

    # Build the graph
    G = nx.Graph()
    for investor, companies in filtered_mapping.items():
        for company in companies:
            G.add_edge(investor, company)

    # Node size based on valuation. Per-company totals are computed once
    # instead of scanning the DataFrame for every node.
    max_valuation = filtered_data["Valuation_Billions"].max()
    # A NaN maximum (no numeric valuations) is truthy, so test it explicitly.
    can_scale = pd.notna(max_valuation) and max_valuation != 0
    company_totals = (
        filtered_data.groupby("Company")["Valuation_Billions"].sum().to_dict()
    )
    node_sizes = []
    for node in G.nodes:
        if node in filtered_mapping:
            node_sizes.append(1500)  # Fixed size for investors
        elif can_scale:
            node_sizes.append((company_totals.get(node, 0) / max_valuation) * 1500)
        else:
            node_sizes.append(100)

    # Node color: Investors (orange), Companies (green)
    node_colors = ["#FF8C00" if node in filtered_mapping else "#32CD32" for node in G.nodes]

    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label='Investor', markersize=10, markerfacecolor='#FF8C00'),
        Line2D([0], [0], marker='o', color='w', label='Company', markersize=10, markerfacecolor='#32CD32'),
    ]

    # Draw on an explicit figure/axes rather than pyplot's implicit "current"
    # figure, and always close the figure, even if drawing raises.
    fig = plt.figure(figsize=(15, 15))
    try:
        ax = fig.add_axes((0, 0, 1, 1))
        pos = nx.spring_layout(G, k=0.2, seed=42)
        nx.draw(
            G, pos,
            ax=ax,
            with_labels=True,
            node_size=node_sizes,
            node_color=node_colors,
            font_size=10,
            edge_color="#A9A9A9",  # Light gray edges
            alpha=0.9,
        )
        ax.legend(handles=legend_elements, loc='upper left')
        ax.set_title("Venture Network Visualization", fontsize=20)
        ax.axis("off")

        # Save plot to BytesIO
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Gradio app function
def app(selected_country, selected_industry):
    """Recompute the investor checklist for the chosen filters.

    Args:
        selected_country: country dropdown value ("All" disables the filter).
        selected_industry: industry dropdown value ("All" disables the filter).

    Returns:
        Tuple of (update for the investor CheckboxGroup with every
        qualifying investor offered and pre-selected, filtered DataFrame).
    """
    investor_list, filtered_data = filter_investors_by_country_and_industry(
        selected_country, selected_industry
    )
    # gr.update(...) patches the output component's properties; the
    # per-component `CheckboxGroup.update` classmethod is gone in Gradio 4+.
    checkbox_update = gr.update(
        choices=investor_list,
        value=investor_list,
        visible=True,
    )
    return checkbox_update, filtered_data
# Gradio Interface
def main():
    """Assemble the Blocks UI, wire up its events and launch it."""
    countries = sorted(data["Country"].dropna().unique())
    industries = sorted(data["Industry"].dropna().unique())
    country_list = ["All"] + countries
    industry_list = ["All"] + industries

    # NOTE(review): the file's indentation was lost; the nesting below (only
    # the two dropdowns inside the Row) is the presumed original layout.
    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
        filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
        graph_output = gr.Image(type="pil", label="Venture Network Graph")
        filtered_data_holder = gr.State()

        # Both dropdowns trigger the same refresh of the checklist and data.
        filter_inputs = [country_filter, industry_filter]
        filter_outputs = [filtered_investor_list, filtered_data_holder]
        for dropdown in filter_inputs:
            dropdown.change(app, inputs=filter_inputs, outputs=filter_outputs)

        # Any change to the checklist redraws the graph.
        filtered_investor_list.change(
            generate_graph,
            inputs=[filtered_investor_list, filtered_data_holder],
            outputs=graph_output,
        )
    demo.launch()
# Launch the UI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()