"""Gradio app that visualises investor -> company networks from a CSV export.

The CSV is loaded and cleaned once at import time. The UI lets the user filter
by country and industry, then choose among investors whose backed companies
(within the current filter) sum to at least ``MIN_TOTAL_VALUATION_BILLIONS``;
the selected investors and their companies are drawn as a spring-layout graph.
"""
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from matplotlib.lines import Line2D
from PIL import Image

# Load and preprocess the dataset
file_path = "cbinsights_data.csv"  # Replace with your file path

# Investors whose filtered portfolio totals less than this are not offered.
MIN_TOTAL_VALUATION_BILLIONS = 20

INVESTOR_COLOR = "#FF8C00"  # orange
COMPANY_COLOR = "#32CD32"  # green
INVESTOR_NODE_SIZE = 1500
FALLBACK_NODE_SIZE = 100  # used when no positive max valuation is available

# Columns every later step relies on (names after the rename below).
_REQUIRED_COLUMNS = ["Company", "Valuation_Billions", "Country", "Industry", "Select_Investors"]

data = pd.read_csv(file_path, skiprows=1)

# Standardize column names: strip, lowercase, and turn inner whitespace into
# underscores so a header written e.g. as "Select Investors" matches the
# snake_case keys used by the rename below.
data.columns = data.columns.str.strip().str.lower().str.replace(r"\s+", "_", regex=True)
print("Standardized Column Names:", data.columns.tolist())


def _strip_if_str(value):
    """Return *value* with surrounding whitespace removed if it is a string."""
    return value.strip() if isinstance(value, str) else value


# Strip every string cell first, so the valuation parsing below sees clean text.
# (Series.map per column replaces DataFrame.applymap, deprecated in pandas 2.1.)
data = data.apply(lambda column: column.map(_strip_if_str))

# Identify the valuation column dynamically
valuation_columns = [col for col in data.columns if "valuation" in col]
if not valuation_columns:
    raise ValueError("No column containing 'Valuation' found in the dataset.")
elif len(valuation_columns) > 1:
    raise ValueError("Multiple columns containing 'Valuation' found. Please specify.")
valuation_column = valuation_columns[0]

# Clean and prepare data: drop "$" and thousands separators, then parse;
# anything unparseable becomes NaN. Raw string avoids the invalid "\$" escape.
data["valuation_billions"] = pd.to_numeric(
    data[valuation_column].replace({r"\$": "", ",": ""}, regex=True),
    errors="coerce",
)

# Rename columns for consistency
data = data.rename(columns={
    "company": "Company",
    "valuation_billions": "Valuation_Billions",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors",
})

_missing_columns = [col for col in _REQUIRED_COLUMNS if col not in data.columns]
if _missing_columns:
    raise ValueError(
        f"Dataset is missing required column(s) {_missing_columns}; "
        f"found {data.columns.tolist()}."
    )


def build_investor_company_mapping(df):
    """Map each investor name to the list of companies it backs in *df*.

    Each row's ``Select_Investors`` cell is split on commas. Rows with a missing
    investor cell or company name are skipped, as are empty names produced by
    stray commas. A company is appended once per row listing the investor.

    Returns:
        dict[str, list]: investor -> companies, in first-seen order.
    """
    mapping = {}
    for company, investors in zip(df["Company"], df["Select_Investors"]):
        if pd.isna(investors) or pd.isna(company):
            continue
        for investor in str(investors).split(","):
            investor = investor.strip()
            if investor:
                mapping.setdefault(investor, []).append(company)
    return mapping


investor_company_mapping = build_investor_company_mapping(data)


def filter_investors_by_country_and_industry(
    selected_country,
    selected_industry,
    min_total_valuation=MIN_TOTAL_VALUATION_BILLIONS,
):
    """Filter the dataset and list investors above a valuation threshold.

    Args:
        selected_country: Country to keep, or "All"/None for no filter.
        selected_industry: Industry to keep, or "All"/None for no filter.
        min_total_valuation: Minimum summed ``Valuation_Billions`` of an
            investor's companies (within the filter) for it to be listed.

    Returns:
        (list of investor names in first-seen order, filtered DataFrame copy)
    """
    filtered_data = data.copy()
    if selected_country not in (None, "All"):
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry not in (None, "All"):
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    mapping = build_investor_company_mapping(filtered_data)
    companies = filtered_data["Company"]
    valuations = filtered_data["Valuation_Billions"]

    # Total valuation of all rows whose company the investor backs.
    investor_list = [
        investor
        for investor, backed in mapping.items()
        if valuations[companies.isin(backed)].sum() >= min_total_valuation
    ]
    return investor_list, filtered_data


def generate_graph(selected_investors, filtered_data):
    """Render the selected investors and their companies as a PIL image.

    Investor nodes have a fixed size; company nodes are scaled by their summed
    valuation relative to the largest single valuation in *filtered_data*.
    Returns None when there is nothing to draw (no selection, no data yet, or
    none of the selected investors present in the current filter).
    """
    if not selected_investors or filtered_data is None or filtered_data.empty:
        return None

    mapping = build_investor_company_mapping(filtered_data)
    # Ignore selections not present under the current filter (stale UI state).
    filtered_mapping = {inv: mapping[inv] for inv in selected_investors if inv in mapping}
    if not filtered_mapping:
        return None

    G = nx.Graph()
    for investor, companies in filtered_mapping.items():
        G.add_edges_from((investor, company) for company in companies)

    # Per-company totals computed once rather than refiltering per node.
    company_valuations = filtered_data.groupby("Company")["Valuation_Billions"].sum()
    max_valuation = filtered_data["Valuation_Billions"].max()
    can_scale = pd.notna(max_valuation) and max_valuation > 0

    node_sizes = []
    node_colors = []
    for node in G.nodes:
        if node in filtered_mapping:
            node_sizes.append(INVESTOR_NODE_SIZE)
            node_colors.append(INVESTOR_COLOR)
        else:
            valuation = company_valuations.get(node, 0.0)
            node_sizes.append(
                (valuation / max_valuation) * INVESTOR_NODE_SIZE if can_scale else FALLBACK_NODE_SIZE
            )
            node_colors.append(COMPANY_COLOR)

    # Explicit figure/axes (not global pyplot state), always closed afterwards.
    fig, ax = plt.subplots(figsize=(15, 15))
    try:
        pos = nx.spring_layout(G, k=0.2, seed=42)
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_size=node_sizes,
            node_color=node_colors,
            font_size=10,
            edge_color="#A9A9A9",  # Light gray edges
            alpha=0.9,
        )
        legend_elements = [
            Line2D([0], [0], marker="o", color="w", label="Investor",
                   markersize=10, markerfacecolor=INVESTOR_COLOR),
            Line2D([0], [0], marker="o", color="w", label="Company",
                   markersize=10, markerfacecolor=COMPANY_COLOR),
        ]
        ax.legend(handles=legend_elements, loc="upper left")
        ax.set_title("Venture Network Visualization", fontsize=20)
        ax.axis("off")

        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now rather than lazily
    return image


def app(selected_country, selected_industry):
    """Gradio callback: refresh the investor checklist and the filtered-data state."""
    investor_list, filtered_data = filter_investors_by_country_and_industry(
        selected_country, selected_industry
    )
    # gr.update works across Gradio versions; gr.CheckboxGroup.update was removed in 4.x.
    return gr.update(choices=investor_list, value=investor_list, visible=True), filtered_data


def main():
    """Build and launch the Gradio interface."""
    country_list = ["All"] + sorted(data["Country"].dropna().unique())
    industry_list = ["All"] + sorted(data["Industry"].dropna().unique())

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")

        filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
        graph_output = gr.Image(type="pil", label="Venture Network Graph")
        filtered_data_holder = gr.State()

        filter_inputs = [country_filter, industry_filter]
        filter_outputs = [filtered_investor_list, filtered_data_holder]
        for dropdown in (country_filter, industry_filter):
            dropdown.change(app, inputs=filter_inputs, outputs=filter_outputs)
        # Populate the investor list on page load instead of waiting for a filter change.
        demo.load(app, inputs=filter_inputs, outputs=filter_outputs)

        filtered_investor_list.change(
            generate_graph,
            inputs=[filtered_investor_list, filtered_data_holder],
            outputs=graph_output,
        )

    demo.launch()


if __name__ == "__main__":
    main()