import logging
import re

import gradio as gr
import networkx as nx
import pandas as pd
import plotly.graph_objects as go

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load and preprocess the dataset
file_path = "cbinsights_data.csv"  # Replace with your actual file path
try:
    # skiprows=1 discards the first physical line of the file, so the header
    # row is read from the second line.
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    logger.error("File not found: %s", file_path)
    raise
except Exception as e:
    logger.error("Error loading CSV file: %s", e)
    raise
else:
    logger.info("CSV file loaded successfully.")

# Standardize column names
data.columns = data.columns.str.strip().str.lower()
logger.info("Standardized Column Names: %s", data.columns.tolist())

# Identify the valuation column (names are already lower-cased above)
valuation_columns = [col for col in data.columns if "valuation" in col]
if len(valuation_columns) != 1:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
valuation_column = valuation_columns[0]
logger.info("Using valuation column: %s", valuation_column)

# Clean and prepare data
# Raw string: "\$" in a normal string literal is an invalid escape sequence.
data["Valuation_Billions"] = data[valuation_column].replace({r"\$": "", ",": ""}, regex=True)
# Anything that still is not a number becomes NaN instead of raising.
data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors="coerce")
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
data.rename(
    columns={
        "company": "Company",
        "date_joined": "Date_Joined",
        "country": "Country",
        "city": "City",
        "industry": "Industry",
        "select_investors": "Select_Investors",
        # Headers are only stripped and lower-cased above, so a CSV whose
        # headers use spaces would otherwise be left unrenamed. Keys that are
        # not present are ignored by rename().
        "date joined": "Date_Joined",
        "select investors": "Select_Investors",
    },
    inplace=True,
)
logger.info("Data cleaned and columns renamed.")

# The rest of this file indexes these columns directly; surface a readable
# hint early if the input file does not provide them.
_missing_columns = [
    name
    for name in ("Company", "Country", "Industry", "Select_Investors")
    if name not in data.columns
]
if _missing_columns:
    logger.warning("Expected columns missing after rename: %s", _missing_columns)


# Build investor-company mapping
def build_investor_company_mapping(df):
    """Map each investor name to the list of companies that list it.

    Args:
        df: DataFrame with a "Company" column and a "Select_Investors" column
            holding comma-separated investor names.

    Returns:
        dict[str, list]: investor name -> companies, in row order. A company
        is appended once per row in which the investor appears.

    Raises:
        KeyError: if "Company" or "Select_Investors" is not a column of df.
    """
    mapping = {}
    for company, investors in zip(df["Company"], df["Select_Investors"]):
        # Missing cells (NaN/None) and any non-text value carry no names.
        if not isinstance(investors, str):
            continue
        for investor in investors.split(","):
            investor = investor.strip()
            if investor:
                mapping.setdefault(investor, []).append(company)
    return mapping
investor_company_mapping = build_investor_company_mapping(data)
logger.info("Investor to company mapping created.")

# Colors cycled through for investor nodes and their legend entries.
_INVESTOR_PALETTE = (
    "#377eb8",  # Blue
    "#e41a1c",  # Red
    "#4daf4a",  # Green
    "#984ea3",  # Purple
    "#ff7f00",  # Orange
    "#ffff33",  # Yellow
    "#a65628",  # Brown
    "#f781bf",  # Pink
    "#999999",  # Grey
)
_COMPANY_COLOR = "#a6d854"
_INVESTOR_NODE_SIZE = 30
_DEFAULT_COMPANY_SIZE = 15
_MIN_COMPANY_SIZE = 10
_AXIS_HIDDEN = {"showgrid": False, "zeroline": False, "visible": False}


def _split_investors(cell):
    """Return the stripped, non-empty investor names of a comma-separated cell."""
    if not isinstance(cell, str):
        return []
    return [name.strip() for name in cell.split(",") if name.strip()]


def _is_active(choice):
    """Return True when a single-select filter holds a real value (not empty, not "All")."""
    return bool(choice) and choice != "All"


# Filter investors by country, industry, investor selection, and company selection
def filter_investors(selected_country, selected_industry, selected_investors, selected_company):
    """Apply the UI filters to the module-level dataset.

    Args:
        selected_country: country name, or "All" / None / "" for no filter.
        selected_industry: industry name, or "All" / None / "" for no filter.
        selected_investors: iterable of investor names; empty or None for no filter.
        selected_company: company name, or "All" / None / "" for no filter.

    Returns:
        tuple: (list of every investor name appearing in the remaining rows,
        the filtered DataFrame).
    """
    filtered_data = data.copy()
    if _is_active(selected_country):
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if _is_active(selected_industry):
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
    if selected_investors:
        # Compare whole names: a substring/regex search would make a short
        # name also match every longer name that contains it.
        wanted = set(selected_investors)
        keep = pd.Series(
            [
                any(name in wanted for name in _split_investors(cell))
                for cell in filtered_data["Select_Investors"]
            ],
            index=filtered_data.index,
            dtype=bool,  # explicit dtype keeps an empty mask a boolean indexer
        )
        filtered_data = filtered_data[keep]
    if _is_active(selected_company):
        filtered_data = filtered_data[filtered_data["Company"] == selected_company]
    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    filtered_investors = list(investor_company_mapping_filtered.keys())
    return filtered_investors, filtered_data


def _company_details(filtered_data):
    """Return {company: (valuation, industry)} using the first row seen per company."""
    details = {}
    rows = zip(
        filtered_data["Company"],
        filtered_data["Valuation_Billions"],
        filtered_data["Industry"],
    )
    for company, valuation, industry in rows:
        if pd.notnull(company):
            details.setdefault(company, (valuation, industry))
    return details


def _build_edge_trace(graph, pos):
    """Return one line trace containing every edge of the graph."""
    edge_x = []
    edge_y = []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The None entry ends the current segment so edges are not joined up.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    return go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color="#aaaaaa"),
        hoverinfo="none",
        mode="lines",
    )


def _build_node_trace(graph, pos, investor_color_map, details):
    """Return the marker+text trace for investor and company nodes."""
    node_x = []
    node_y = []
    node_text = []
    node_color = []
    node_size = []
    node_hovertext = []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        if node in investor_color_map:
            # Investors: fixed size, labelled, colored per investor.
            node_text.append(node)
            node_color.append(investor_color_map[node])
            node_size.append(_INVESTOR_NODE_SIZE)
            node_hovertext.append(f"Investor: {node}")
            continue

        valuation, industry = details.get(node, (None, None))
        has_valuation = pd.notnull(valuation)
        if has_valuation:
            size = max(valuation * 5, _MIN_COMPANY_SIZE)
        else:
            size = _DEFAULT_COMPANY_SIZE
        node_size.append(size)
        node_text.append("")
        node_color.append(_COMPANY_COLOR)
        # "<br>" is the line-break marker in Plotly hover text.
        hovertext = f"Company: {node}"
        if pd.notnull(industry):
            hovertext += f"<br>Industry: {industry}"
        if has_valuation:
            hovertext += f"<br>Valuation: ${valuation}B"
        node_hovertext.append(hovertext)

    return go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        mode="markers+text",
        hoverinfo="text",
        hovertext=node_hovertext,
        marker=dict(
            showscale=False,
            size=node_size,
            color=node_color,
            line=dict(width=0.5, color="#333333"),
        ),
        textposition="middle center",
        textfont=dict(size=12, color="#000000"),
    )


# Generate Plotly graph
def generate_graph(investors, filtered_data):
    """Build the investor-company network figure.

    Args:
        investors: list of investor names to draw.
        filtered_data: DataFrame with "Company", "Select_Investors",
            "Industry" and "Valuation_Billions" columns.

    Returns:
        plotly.graph_objects.Figure: an empty figure when investors is empty;
        otherwise one legend entry per investor plus an edge trace and a node
        trace. Company markers scale with valuation (times 5, at least 10,
        or 15 when the valuation is missing).
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # Cycle through the palette when there are more investors than colors.
    palette_size = len(_INVESTOR_PALETTE)
    investor_color_map = {
        investor: _INVESTOR_PALETTE[i % palette_size]
        for i, investor in enumerate(investors)
    }

    # Exact investor -> companies lookup; avoids substring false positives.
    companies_by_investor = build_investor_company_mapping(filtered_data)
    G = nx.Graph()
    for investor in investors:
        for company in companies_by_investor.get(investor, []):
            G.add_node(company)
            G.add_node(investor)
            G.add_edge(investor, company)

    pos = nx.spring_layout(G, seed=42)  # fixed seed for a repeatable layout

    edge_trace = _build_edge_trace(G, pos)
    node_trace = _build_node_trace(G, pos, investor_color_map, _company_details(filtered_data))

    # Point-less traces whose only purpose is to create legend entries.
    legend_items = [
        go.Scatter(
            x=[None],
            y=[None],
            mode="markers",
            marker=dict(size=10, color=investor_color_map[investor]),
            legendgroup=investor,
            showlegend=True,
            name=investor,
        )
        for investor in investors
    ]

    fig = go.Figure(data=legend_items + [edge_trace, node_trace])
    fig.update_layout(
        # title.font replaces the deprecated "titlefont" layout attribute.
        title=dict(text="Venture Networks", font=dict(size=24)),
        margin=dict(l=20, r=20, t=60, b=20),
        hovermode="closest",
        width=1200,
        height=800,
        autosize=True,
        xaxis=dict(_AXIS_HIDDEN),
        yaxis=dict(_AXIS_HIDDEN),
    )
    return fig


# Gradio app
def app(selected_country, selected_industry, selected_company, selected_investors):
    """Gradio callback: return (investor summary text, network figure).

    The parameter order matches the `inputs` list wired up in main().
    """
    investors, filtered_data = filter_investors(
        selected_country, selected_industry, selected_investors, selected_company
    )
    if not investors:
        return "No investors found with the selected filters.", go.Figure()
    graph = generate_graph(investors, filtered_data)
    return ", ".join(investors), graph


# Main function
def main():
    """Build the Gradio Blocks UI and launch it."""
    country_list = ["All"] + sorted(data["Country"].dropna().unique())
    industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
    company_list = ["All"] + sorted(data["Company"].dropna().unique())
    investor_list = sorted(investor_company_mapping.keys())

    with gr.Blocks(title="Venture Networks Visualization") as demo:
        gr.Markdown(
            "# Venture Networks Visualization\n\n"
            "Explore the connections between investors and companies in the "
            "venture capital ecosystem. Use the filters below to customize "
            "the network graph."
        )
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
            company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
            investor_filter = gr.Dropdown(
                choices=investor_list,
                label="Select Investors",
                value=[],
                multiselect=True,
            )
        with gr.Row():
            investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
            graph_output = gr.Plot(label="Network Graph")

        inputs = [country_filter, industry_filter, company_filter, investor_filter]
        outputs = [investor_output, graph_output]
        # Every filter re-runs the same callback with all four current values.
        for control in inputs:
            control.change(app, inputs, outputs)

        gr.Markdown("**Instructions:** Use the dropdowns to filter the network graph.")

    demo.launch()


if __name__ == "__main__":
    main()