import pandas as pd import networkx as nx import plotly.graph_objects as go import gradio as gr import re import logging # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Load and preprocess the dataset file_path = "cbinsights_data.csv" # Replace with your actual file path try: data = pd.read_csv(file_path, skiprows=1) logger.info("CSV file loaded successfully.") except FileNotFoundError: logger.error(f"File not found: {file_path}") raise except Exception as e: logger.error(f"Error loading CSV file: {e}") raise # Standardize column names data.columns = data.columns.str.strip().str.lower() logger.info(f"Standardized Column Names: {data.columns.tolist()}") # Filter out Health since Healthcare is the correct Market Segment data = data[data.industry != 'Health'] # Identify the valuation column valuation_columns = [col for col in data.columns if 'valuation' in col.lower()] if len(valuation_columns) != 1: logger.error("Unable to identify a single valuation column.") raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.") valuation_column = valuation_columns[0] logger.info(f"Using valuation column: {valuation_column}") # Clean and prepare data data["Valuation_Billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True) data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce') data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col) data.rename(columns={ "company": "Company", "date_joined": "Date_Joined", "country": "Country", "city": "City", "industry": "Industry", "select_investors": "Select_Investors" }, inplace=True) logger.info("Data cleaned and columns renamed.") # Build investor-company mapping def build_investor_company_mapping(df): mapping = {} for _, row in df.iterrows(): company = row["Company"] investors = row["Select_Investors"] if pd.notnull(investors): for investor in investors.split(","): investor = investor.strip() if investor: mapping.setdefault(investor, []).append(company) return mapping investor_company_mapping = build_investor_company_mapping(data) logger.info("Investor to company mapping created.") # ------------------------- # Valuation-Range Logic # ------------------------- def filter_by_valuation_range(df, selected_valuation_range): """Filter dataframe by the specified valuation range in billions.""" if selected_valuation_range == "All": return df # No further filtering if selected_valuation_range == "1-5": return df[(df["Valuation_Billions"] >= 1) & (df["Valuation_Billions"] < 5)] elif selected_valuation_range == "5-10": return df[(df["Valuation_Billions"] >= 5) & (df["Valuation_Billions"] < 10)] elif selected_valuation_range == "10-15": return df[(df["Valuation_Billions"] >= 10) & (df["Valuation_Billions"] < 15)] elif selected_valuation_range == "15-20": return df[(df["Valuation_Billions"] >= 15) & (df["Valuation_Billions"] < 20)] elif selected_valuation_range == "20+": return df[df["Valuation_Billions"] >= 20] else: return df # Fallback, should never happen # Filter investors by country, industry, investor selection, company selection, and valuation range def filter_investors( selected_country, selected_industry, selected_investors, selected_company, exclude_countries, exclude_industries, exclude_companies, exclude_investors, selected_valuation_range ): filtered_data = data.copy() # 1) Valuation range filter filtered_data = filter_by_valuation_range(filtered_data, selected_valuation_range) # 2) Now apply the existing filters: # Inclusion filters if selected_country != "All": filtered_data = filtered_data[filtered_data["Country"] == selected_country] if selected_industry != "All": filtered_data = filtered_data[filtered_data["Industry"] == selected_industry] if selected_company != "All": filtered_data = filtered_data[filtered_data["Company"] == selected_company] if selected_investors: pattern = '|'.join([re.escape(inv) for inv in selected_investors]) filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)] # Exclusion filters if exclude_countries: filtered_data = filtered_data[~filtered_data["Country"].isin(exclude_countries)] if exclude_industries: filtered_data = filtered_data[~filtered_data["Industry"].isin(exclude_industries)] if exclude_companies: filtered_data = filtered_data[~filtered_data["Company"].isin(exclude_companies)] if exclude_investors: pattern = '|'.join([re.escape(inv) for inv in exclude_investors]) filtered_data = filtered_data[~filtered_data["Select_Investors"].str.contains(pattern, na=False)] investor_company_mapping_filtered = build_investor_company_mapping(filtered_data) filtered_investors = list(investor_company_mapping_filtered.keys()) return filtered_investors, filtered_data # Generate Plotly graph # NEW: We add selected_valuation_range so we can check if the user selected 15-20 or 20+ def generate_graph(investors, filtered_data, selected_valuation_range): if not investors: logger.warning("No investors selected.") return go.Figure() # Create a color map for investors unique_investors = investors num_colors = len(unique_investors) color_palette = [ "#377eb8", "#e41a1c", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999" ] while num_colors > len(color_palette): color_palette.extend(color_palette) investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)} G = nx.Graph() for investor in investors: companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist() for company in companies: G.add_node(company) G.add_node(investor) G.add_edge(investor, company) pos = nx.spring_layout(G, seed=1721, iterations=150) edge_x, edge_y = [], [] for edge in G.edges(): x0, y0 = pos[edge[0]] x1, y1 = pos[edge[1]] edge_x.extend([x0, x1, None]) edge_y.extend([y0, y1, None]) edge_trace = go.Scatter( x=edge_x, y=edge_y, line=dict(width=0.5, color='#aaaaaa'), hoverinfo='none', mode='lines' ) node_x, node_y, node_text, node_textposition = [], [], [], [] node_color, node_size, node_hovertext = [], [], [] for node in G.nodes(): x, y = pos[node] node_x.append(x) node_y.append(y) if node in investors: node_text.append(node) # Add investor labels node_color.append(investor_color_map[node]) node_size.append(30) node_hovertext.append(f"Investor: {node}") node_textposition.append('top center') else: valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values industry = filtered_data.loc[filtered_data["Company"] == node, "Industry"].values size = valuation[0] * 5 if len(valuation) > 0 and not pd.isnull(valuation[0]) else 15 node_size.append(max(size, 10)) node_color.append("#a6d854") # Build the hover label text hovertext = f"Company: {node}" if len(industry) > 0 and not pd.isnull(industry[0]): hovertext += f"
Industry: {industry[0]}" if len(valuation) > 0 and not pd.isnull(valuation[0]): hovertext += f"
Valuation: ${valuation[0]:.2f}B" node_hovertext.append(hovertext) # NEW: If valuation range is 15–20 or 20+, show hovertext for all companies if selected_valuation_range in ["15-20", "20+"]: node_text.append(hovertext) # show full text node_textposition.append('bottom center') else: # Old logic: only show the company name in certain conditions if ( (len(valuation) > 0 and valuation[0] is not None and valuation[0] > 10) # Check if > 10B or (len(filtered_data) < 15) or (node in filtered_data.nlargest(5, "Valuation_Billions")["Company"].tolist()) ): node_text.append(node) # Show just the company name node_textposition.append('bottom center') else: node_text.append("") # Hide company label node_textposition.append('bottom center') node_trace = go.Scatter( x=node_x, y=node_y, text=node_text, textposition=node_textposition, mode='markers+text', hoverinfo='text', hovertext=node_hovertext, textfont=dict(size=13), # Adjust label font size marker=dict( showscale=False, size=node_size, color=node_color, line=dict(width=0.5, color='#333333') ) ) # Compute total market cap total_market_cap = filtered_data["Valuation_Billions"].sum() fig = go.Figure(data=[edge_trace, node_trace]) fig.update_layout( title="", titlefont_size=28, margin=dict(l=20, r=20, t=60, b=20), hovermode='closest', width=1200, height=800, autosize=True, xaxis=dict(showgrid=False, zeroline=False, visible=False), yaxis=dict(showgrid=False, zeroline=False, visible=False), showlegend=False, # Hide the legend to maximize space annotations=[ dict( x=0.5, y=1.1, xref='paper', yref='paper', text=f"Combined Market Cap: ${total_market_cap:.1f} Billions", showarrow=False, font=dict(size=14), xanchor='center' ) ] ) return fig # Gradio app function def app( selected_country, selected_industry, selected_company, selected_investors, exclude_countries, exclude_industries, exclude_companies, exclude_investors, selected_valuation_range ): investors, filtered_data = filter_investors( selected_country, selected_industry, selected_investors, selected_company, exclude_countries, exclude_industries, exclude_companies, exclude_investors, selected_valuation_range ) if not investors: return go.Figure() # NEW: Pass valuation_range to generate_graph graph = generate_graph(investors, filtered_data, selected_valuation_range) return graph def main(): country_list = ["All"] + sorted(data["Country"].dropna().unique()) industry_list = ["All"] + sorted(data["Industry"].dropna().unique()) company_list = ["All"] + sorted(data["Company"].dropna().unique()) investor_list = sorted(investor_company_mapping.keys()) # Valuation range choices valuation_ranges = ["All", "1-5", "5-10", "10-15", "15-20", "20+"] with gr.Blocks(title="Venture Networks Visualization") as demo: gr.Markdown("# Venture Networks Visualization") with gr.Row(): country_filter = gr.Dropdown(choices=country_list, label="Country", value="All") industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All") company_filter = gr.Dropdown(choices=company_list, label="Company", value="All") investor_filter = gr.Dropdown(choices=investor_list, label="Select Investors", value=[], multiselect=True) with gr.Row(): valuation_range_filter = gr.Dropdown( choices=valuation_ranges, label="Valuation Range (Billions)", value="All" ) exclude_country_filter = gr.Dropdown(choices=country_list[1:], label="Exclude Country", value=[], multiselect=True) exclude_industry_filter = gr.Dropdown(choices=industry_list[1:], label="Exclude Industry", value=[], multiselect=True) exclude_company_filter = gr.Dropdown(choices=company_list[1:], label="Exclude Company", value=[], multiselect=True) exclude_investor_filter = gr.Dropdown(choices=investor_list, label="Exclude Investors", value=[], multiselect=True) graph_output = gr.Plot(label="Network Graph") inputs = [ country_filter, industry_filter, company_filter, investor_filter, exclude_country_filter, exclude_industry_filter, exclude_company_filter, exclude_investor_filter, valuation_range_filter ] outputs = [graph_output] # Set up event triggers for all inputs for input_control in inputs: input_control.change(app, inputs, outputs) gr.Markdown("**Instructions:** Use the dropdowns to filter the network graph. For valuation ranges 15–20 or 20+, you’ll see each company's info label without hovering.") gr.Markdown("**Note:** All companies are in green, while venture firms have different colors. The diameter of the company circle varies proportionate to the valuation.") demo.launch() if __name__ == "__main__": main()