"""Gradio app that draws an investor/company network from a CSV export.

At import time the module loads ``cbinsights_data.csv``, normalizes its
column names, derives a numeric ``Valuation_Billions`` column, and builds an
investor -> companies mapping.  ``main()`` launches a Gradio UI that filters
investors by country, industry, and summed portfolio valuation, and renders
the resulting bipartite graph with Plotly.
"""

import logging

import gradio as gr
import networkx as nx
import pandas as pd
import plotly.graph_objects as go

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load and preprocess the dataset
file_path = "cbinsights_data.csv"  # Replace with your actual file path
try:
    # skiprows=1 drops the first physical line of the file, so the second
    # line is parsed as the header row.
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    logger.error("File not found: %s", file_path)
    raise
except Exception as e:
    logger.error("Error loading CSV file: %s", e)
    raise
else:
    logger.info("CSV file loaded successfully.")

# Standardize column names: trim, lowercase, and turn whitespace runs into
# underscores so headers such as "Select Investors" match the
# "select_investors" keys used by the rename map below.
data.columns = (
    data.columns.str.strip().str.lower().str.replace(r"\s+", "_", regex=True)
)
logger.info("Standardized Column Names: %s", data.columns.tolist())

# Identify the valuation column
valuation_columns = [col for col in data.columns if "valuation" in col.lower()]
if len(valuation_columns) != 1:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError(
        "Dataset should contain exactly one column with 'valuation' in its name."
    )

valuation_column = valuation_columns[0]
logger.info("Using valuation column: %s", valuation_column)

# Clean and prepare data: drop "$" and thousands separators, then coerce to
# numbers (anything unparsable becomes NaN).  The pattern is a raw string so
# "\$" is a regex escape rather than an invalid Python string escape.
data["valuation_billions"] = pd.to_numeric(
    data[valuation_column].replace({r"\$": "", ",": ""}, regex=True),
    errors="coerce",
)
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
data.rename(
    columns={
        "company": "Company",
        "valuation_billions": "Valuation_Billions",
        "date_joined": "Date_Joined",
        "country": "Country",
        "city": "City",
        "industry": "Industry",
        "select_investors": "Select_Investors",
    },
    inplace=True,
)
logger.info("Data cleaned and columns renamed.")


# Build investor-company mapping
def build_investor_company_mapping(df):
    """Return a dict mapping each investor name to the companies it backs.

    Args:
        df: DataFrame with a ``Company`` column and a ``Select_Investors``
            column holding comma-separated investor names.

    Returns:
        dict[str, list]: investor name -> list of ``Company`` values, in row
        order.  Rows with a null ``Select_Investors`` value and empty tokens
        are skipped; each token is whitespace-stripped.
    """
    mapping = {}
    for company, investors in zip(df["Company"], df["Select_Investors"]):
        if pd.isnull(investors):
            continue
        for investor in investors.split(","):
            investor = investor.strip()
            if investor:
                mapping.setdefault(investor, []).append(company)
    return mapping


investor_company_mapping = build_investor_company_mapping(data)
logger.info("Investor to company mapping created.")


# Filter investors by country, industry, and valuation threshold
def filter_investors(selected_country, selected_industry, valuation_threshold):
    """Select investors whose filtered portfolio meets a valuation threshold.

    Args:
        selected_country: Value of ``Country`` to keep, or ``"All"`` for no
            country filter.
        selected_industry: Value of ``Industry`` to keep, or ``"All"`` for no
            industry filter.
        valuation_threshold: Minimum summed ``Valuation_Billions`` an
            investor's companies (within the filtered rows) must reach.

    Returns:
        tuple: ``(filtered_investors, filtered_data)`` where the first item is
        the list of qualifying investor names and the second is the filtered
        copy of the module-level ``data`` frame.
    """
    filtered_data = data.copy()
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)

    # Sum the valuation of every filtered row whose company belongs to the
    # investor; isin() counts each row once even if a company is listed twice.
    investor_valuations = {
        investor: filtered_data.loc[
            filtered_data["Company"].isin(companies), "Valuation_Billions"
        ].sum()
        for investor, companies in investor_company_mapping_filtered.items()
    }
    filtered_investors = [
        investor
        for investor, total in investor_valuations.items()
        if total >= valuation_threshold
    ]
    return filtered_investors, filtered_data


# Generate Plotly graph
def generate_graph(investors, filtered_data):
    """Build a Plotly figure of the investor/company network.

    Args:
        investors: Iterable of investor names to draw.
        filtered_data: DataFrame with ``Company``, ``Select_Investors`` and
            ``Valuation_Billions`` columns.

    Returns:
        plotly.graph_objects.Figure: an empty figure when ``investors`` is
        empty; otherwise a figure with one line trace for edges and one marker
        trace for nodes, laid out with ``nx.spring_layout(seed=42)``.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # Link investors to companies by exact, comma-split investor names.  A
    # substring/regex search on the raw column would also match longer names
    # that merely contain the investor's name and would misinterpret regex
    # metacharacters such as "(" or "+" appearing in a name.
    companies_by_investor = build_investor_company_mapping(filtered_data)
    investor_set = set(investors)

    G = nx.Graph()
    for investor in investors:
        for company in companies_by_investor.get(investor, []):
            G.add_edge(investor, company)

    pos = nx.spring_layout(G, seed=42)

    # Each edge contributes (start, end, None); the None breaks the line so
    # all edges can share a single Scatter trace.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    node_x = []
    node_y = []
    node_text = []
    node_color = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
        if node in investor_set:
            node_color.append(20)  # Fixed color value for investors
        else:
            valuation = filtered_data.loc[
                filtered_data["Company"] == node, "Valuation_Billions"
            ].sum()
            node_color.append(valuation)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        mode="markers",
        hoverinfo="text",
        marker=dict(
            showscale=True,
            colorscale="YlGnBu",
            size=10,
            color=node_color,
            colorbar=dict(
                thickness=15,
                # Nested title dict replaces the deprecated `titleside`.
                title=dict(text="Valuation (B)", side="right"),
                xanchor="left",
            ),
        ),
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        showlegend=False,
        # Nested title dict replaces the deprecated `titlefont_size`.
        title=dict(text="Venture Networks", font=dict(size=16)),
        margin=dict(l=40, r=40, t=40, b=40),
        hovermode="closest",
    )
    return fig


# Gradio app
def app(selected_country, selected_industry, valuation_threshold):
    """Gradio callback: filter investors and build the matching graph.

    Returns:
        tuple: ``(investors, graph)`` - the list of qualifying investor names
        and the Plotly figure produced by ``generate_graph``.
    """
    investors, filtered_data = filter_investors(
        selected_country, selected_industry, valuation_threshold
    )
    graph = generate_graph(investors, filtered_data)
    return investors, graph


def main():
    """Build the Gradio Blocks UI and launch it."""
    country_list = ["All"] + sorted(data["Country"].dropna().unique())
    industry_list = ["All"] + sorted(data["Industry"].dropna().unique())

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(
                choices=country_list, label="Country", value="All"
            )
            industry_filter = gr.Dropdown(
                choices=industry_list, label="Industry", value="All"
            )
            valuation_slider = gr.Slider(
                0, 50, value=20, step=1, label="Valuation Threshold (B)"
            )

        investor_output = gr.Textbox(label="Filtered Investors")
        graph_output = gr.Plot(label="Network Graph")

        inputs = [country_filter, industry_filter, valuation_slider]
        outputs = [investor_output, graph_output]
        # Any change to a filter control re-runs the same callback.
        for control in inputs:
            control.change(app, inputs, outputs)

    demo.launch()


if __name__ == "__main__":
    main()