import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import gradio as gr
import re
import logging
# Logging: configure the root logger at INFO level; this module logs through
# its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the raw dataset. skiprows=1 discards the file's first line, so the
# header row is read from the second line. Failures are logged, then re-raised.
file_path = "cbinsights_data.csv" # Replace with your actual file path
try:
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    logger.error(f"File not found: {file_path}")
    raise
except Exception as e:
    logger.error(f"Error loading CSV file: {e}")
    raise
else:
    logger.info("CSV file loaded successfully.")
# Standardize column names: trim surrounding whitespace and lower-case them so
# the lookups below do not depend on the CSV header's exact formatting.
data.columns = data.columns.str.strip().str.lower()
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Trim whitespace in every text (object-dtype) column BEFORE filtering, so that
# padded values such as " Health" are caught by the comparison below.
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)

# Filter out Health since Healthcare is the correct Market Segment.
# .copy() makes the result an independent frame, so the column assignment
# further down does not write into a slice of the pre-filter frame
# (which pandas reports as SettingWithCopyWarning).
data = data[data["industry"] != "Health"].copy()

# Identify the valuation column: exactly one column name must contain "valuation".
valuation_columns = [col for col in data.columns if "valuation" in col.lower()]
if len(valuation_columns) != 1:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
valuation_column = valuation_columns[0]
logger.info(f"Using valuation column: {valuation_column}")

# Clean and prepare data: strip "$" signs and thousands separators, then parse
# as numbers; anything unparseable becomes NaN (errors="coerce").
# The raw-string pattern avoids the invalid "\$" escape in a normal string.
data["Valuation_Billions"] = pd.to_numeric(
    data[valuation_column].replace({r"\$": "", ",": ""}, regex=True),
    errors="coerce",
)

# Rename the lower-cased source columns to the display names used by the
# filtering, graph and UI code.
data.rename(columns={
    "company": "Company",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors"
}, inplace=True)
logger.info("Data cleaned and columns renamed.")
def build_investor_company_mapping(df):
    """Map every investor name to the list of companies it appears against.

    Each ``Select_Investors`` cell is split on "," and each piece is stripped;
    empty pieces and missing cells are ignored. Companies are listed in the
    order their rows appear in ``df`` (a company appears once per occurrence).

    Args:
        df: Frame with ``Company`` and ``Select_Investors`` columns.

    Returns:
        dict mapping investor name -> list of company names.
    """
    investor_to_companies = {}
    for company_name, investor_cell in zip(df["Company"], df["Select_Investors"]):
        if not pd.notnull(investor_cell):
            continue
        for raw_name in investor_cell.split(","):
            name = raw_name.strip()
            if not name:
                continue
            if name not in investor_to_companies:
                investor_to_companies[name] = []
            investor_to_companies[name].append(company_name)
    return investor_to_companies
# Investor -> companies lookup over the full (unfiltered) dataset; main() uses
# its keys as the choices for the investor include/exclude dropdowns.
investor_company_mapping = build_investor_company_mapping(data)
logger.info("Investor to company mapping created.")
# -------------------------
# Valuation-Range Logic
# -------------------------
def filter_by_valuation_range(df, selected_valuation_range):
    """Filter dataframe by the specified valuation range in billions.

    Ranges are half-open, ``low <= Valuation_Billions < high``; ``"20+"`` has
    no upper bound. ``"All"``, or any label not listed below, returns ``df``
    itself unchanged. Rows with a NaN valuation never match a bounded range.
    """
    # label -> (inclusive lower bound, exclusive upper bound or None)
    range_bounds = {
        "1-5": (1, 5),
        "5-10": (5, 10),
        "10-15": (10, 15),
        "15-20": (15, 20),
        "20+": (20, None),
    }
    if selected_valuation_range not in range_bounds:
        return df  # "All" or unknown label: no filtering

    lower, upper = range_bounds[selected_valuation_range]
    valuations = df["Valuation_Billions"]
    in_range = valuations >= lower
    if upper is not None:
        in_range = in_range & (valuations < upper)
    return df[in_range]
def _investor_names(cell):
    """Return the set of non-empty, stripped investor names in one cell.

    ``Select_Investors`` cells are comma-separated strings; any non-string
    value (e.g. NaN for a missing cell) yields an empty set.
    """
    if not isinstance(cell, str):
        return set()
    return {name.strip() for name in cell.split(",") if name.strip()}


def filter_investors(
    selected_country,
    selected_industry,
    selected_investors,
    selected_company,
    exclude_countries,
    exclude_industries,
    exclude_companies,
    exclude_investors,
    selected_valuation_range
):
    """Apply every UI filter to the module-level ``data`` frame.

    Investor include/exclude filters compare whole, comma-separated investor
    names (the same split/strip rule as ``build_investor_company_mapping``).
    Substring matching is not used, so selecting "Accel" does not also pull in
    rows that only list "Accel-KKR".

    Args:
        selected_country / selected_industry / selected_company: A single
            value to keep, or "All" for no filter. None/"" (a cleared
            dropdown, presumably) is also treated as "All".
        selected_investors: Keep rows backed by at least one of these investors.
        exclude_countries / exclude_industries / exclude_companies: Drop rows
            whose value is in the given collection.
        exclude_investors: Drop rows backed by any of these investors.
        selected_valuation_range: Label accepted by ``filter_by_valuation_range``.

    Returns:
        ``(investors, filtered_data)``. ``investors`` lists the investor names
        found in the filtered rows. ``filtered_data`` is the filtered frame.
    """
    # 1) Valuation range filter, applied to a copy so ``data`` is never mutated.
    filtered_data = filter_by_valuation_range(data.copy(), selected_valuation_range)

    # 2) Inclusion filters ("All" or an empty selection means "no filter").
    if selected_country and selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry and selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
    if selected_company and selected_company != "All":
        filtered_data = filtered_data[filtered_data["Company"] == selected_company]
    if selected_investors:
        wanted = set(selected_investors)
        # .astype(bool) keeps the mask boolean even when the frame is empty.
        backs_wanted = filtered_data["Select_Investors"].map(
            lambda cell: not _investor_names(cell).isdisjoint(wanted)
        ).astype(bool)
        filtered_data = filtered_data[backs_wanted]

    # 3) Exclusion filters.
    if exclude_countries:
        filtered_data = filtered_data[~filtered_data["Country"].isin(exclude_countries)]
    if exclude_industries:
        filtered_data = filtered_data[~filtered_data["Industry"].isin(exclude_industries)]
    if exclude_companies:
        filtered_data = filtered_data[~filtered_data["Company"].isin(exclude_companies)]
    if exclude_investors:
        unwanted = set(exclude_investors)
        backs_unwanted = filtered_data["Select_Investors"].map(
            lambda cell: not _investor_names(cell).isdisjoint(unwanted)
        ).astype(bool)
        filtered_data = filtered_data[~backs_unwanted]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    filtered_investors = list(investor_company_mapping_filtered.keys())
    return filtered_investors, filtered_data
def _companies_by_investor(filtered_data, investor_set):
    """Map each investor in ``investor_set`` to the companies it backs.

    Cells are split on "," and each name is stripped, the same way as in
    ``build_investor_company_mapping``. Only whole names match, so "Accel"
    does not pick up a row listing only "Accel-KKR". Companies keep their row
    order from ``filtered_data``.
    """
    mapping = {}
    for company, cell in zip(filtered_data["Company"], filtered_data["Select_Investors"]):
        if not isinstance(cell, str):
            continue
        for name in cell.split(","):
            name = name.strip()
            if name in investor_set:
                mapping.setdefault(name, []).append(company)
    return mapping


def _company_details(filtered_data):
    """Return ``{company: (valuation, industry)}`` taken from each company's first row."""
    details = {}
    for company, valuation, industry in zip(
        filtered_data["Company"],
        filtered_data["Valuation_Billions"],
        filtered_data["Industry"],
    ):
        details.setdefault(company, (valuation, industry))
    return details


def generate_graph(investors, filtered_data, selected_valuation_range):
    """Build a Plotly network figure linking investors to the companies they back.

    Each investor node gets a colour from a fixed palette, cycled when there
    are more investors than colours. Company nodes are green and their marker
    size grows with valuation.

    Args:
        investors: Investor names to draw.
        filtered_data: Rows to draw. Must have ``Company``,
            ``Select_Investors``, ``Industry`` and ``Valuation_Billions`` columns.
        selected_valuation_range: Active valuation-range label.
            - For "15-20" or "20+", every company shows its full info text as
              its on-graph label.
            - Otherwise, a company is labelled by name only when it is valued
              over 10, when fewer than 15 rows are shown, or when it is among
              the 5 highest valuations.

    Returns:
        A ``go.Figure``; an empty figure when ``investors`` is empty.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # Investor colours, cycled by index when there are more investors than colours.
    color_palette = [
        "#377eb8", "#e41a1c", "#4daf4a", "#984ea3",
        "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999"
    ]
    investor_color_map = {
        investor: color_palette[i % len(color_palette)]
        for i, investor in enumerate(investors)
    }
    investor_set = set(investors)

    # Investor-company edges; nodes are inserted in a deterministic order and
    # the layout is seeded, so the same inputs give the same picture.
    companies_by_investor = _companies_by_investor(filtered_data, investor_set)
    G = nx.Graph()
    for investor in investors:
        for company in companies_by_investor.get(investor, []):
            G.add_node(company)
            G.add_node(investor)
            G.add_edge(investor, company)

    pos = nx.spring_layout(G, seed=1721, iterations=150)

    # One polyline for all edges; a None between segments breaks the line.
    edge_x, edge_y = [], []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#aaaaaa'),
        hoverinfo='none',
        mode='lines'
    )

    # Everything the per-node labelling needs, computed once instead of per node.
    company_details = _company_details(filtered_data)
    show_full_labels = selected_valuation_range in ("15-20", "20+")
    few_rows = len(filtered_data) < 15
    top_companies = (
        set()
        if show_full_labels
        else set(filtered_data.nlargest(5, "Valuation_Billions")["Company"])
    )

    node_x, node_y, node_text, node_textposition = [], [], [], []
    node_color, node_size, node_hovertext = [], [], []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

        if node in investor_set:
            node_text.append(node)  # investors are always labelled
            node_color.append(investor_color_map[node])
            node_size.append(30)
            node_hovertext.append(f"Investor: {node}")
            node_textposition.append('top center')
            continue

        valuation, industry = company_details.get(node, (None, None))
        has_valuation = bool(pd.notnull(valuation))
        # Marker size: 5x valuation (15 when unknown), never below 10.
        size = valuation * 5 if has_valuation else 15
        node_size.append(max(size, 10))
        node_color.append("#a6d854")

        # Hover text; Plotly renders <br> as a line break.
        hovertext = f"Company: {node}"
        if pd.notnull(industry):
            hovertext += f"<br>Industry: {industry}"
        if has_valuation:
            hovertext += f"<br>Valuation: ${valuation:.2f}B"
        node_hovertext.append(hovertext)

        if show_full_labels:
            node_text.append(hovertext)  # full info shown without hovering
        elif (has_valuation and valuation > 10) or few_rows or node in top_companies:
            node_text.append(node)  # just the company name
        else:
            node_text.append("")  # hide the label
        node_textposition.append('bottom center')

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        textposition=node_textposition,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_hovertext,
        textfont=dict(size=13),  # label font size
        marker=dict(
            showscale=False,
            size=node_size,
            color=node_color,
            line=dict(width=0.5, color='#333333')
        )
    )

    # Combined valuation of the filtered rows (NaN valuations are skipped by sum()).
    total_market_cap = filtered_data["Valuation_Billions"].sum()

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        # title.font replaces the deprecated `titlefont` property, which
        # Plotly 6 no longer accepts.
        title=dict(text="", font=dict(size=28)),
        margin=dict(l=20, r=20, t=60, b=20),
        hovermode='closest',
        width=1200,
        height=800,
        autosize=True,
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        yaxis=dict(showgrid=False, zeroline=False, visible=False),
        showlegend=False,  # hide the legend to maximize space
        annotations=[
            dict(
                x=0.5, y=1.1, xref='paper', yref='paper',
                text=f"Combined Market Cap: ${total_market_cap:.1f} Billions",
                showarrow=False, font=dict(size=14), xanchor='center'
            )
        ]
    )
    return fig
def app(
    selected_country,
    selected_industry,
    selected_company,
    selected_investors,
    exclude_countries,
    exclude_industries,
    exclude_companies,
    exclude_investors,
    selected_valuation_range
):
    """Gradio change-callback: filter the data, then draw the network.

    The parameters arrive in the order of the ``inputs`` list wired up in
    ``main``. ``filter_investors`` expects investors before company, so the
    call below passes everything by keyword.

    Returns:
        The figure from ``generate_graph``, or an empty ``go.Figure`` when no
        investor survives the filters.
    """
    investors, filtered_data = filter_investors(
        selected_country=selected_country,
        selected_industry=selected_industry,
        selected_investors=selected_investors,
        selected_company=selected_company,
        exclude_countries=exclude_countries,
        exclude_industries=exclude_industries,
        exclude_companies=exclude_companies,
        exclude_investors=exclude_investors,
        selected_valuation_range=selected_valuation_range,
    )
    if investors:
        return generate_graph(investors, filtered_data, selected_valuation_range)
    return go.Figure()
def main():
    """Build the Gradio Blocks UI for the network graph and launch it.

    Every filter control is wired so that any change re-runs ``app`` with the
    full set of filter values and redraws the plot.
    """
    # Distinct, non-null, sorted values for each categorical column.
    countries = sorted(data["Country"].dropna().unique())
    industries = sorted(data["Industry"].dropna().unique())
    companies = sorted(data["Company"].dropna().unique())
    investor_list = sorted(investor_company_mapping.keys())
    # Labels understood by filter_by_valuation_range.
    valuation_ranges = ["All", "1-5", "5-10", "10-15", "15-20", "20+"]

    with gr.Blocks(title="Venture Networks Visualization") as demo:
        gr.Markdown("# Venture Networks Visualization")

        # Row 1: inclusion filters ("All" disables a single-choice filter).
        with gr.Row():
            country_filter = gr.Dropdown(choices=["All", *countries], label="Country", value="All")
            industry_filter = gr.Dropdown(choices=["All", *industries], label="Industry", value="All")
            company_filter = gr.Dropdown(choices=["All", *companies], label="Company", value="All")
            investor_filter = gr.Dropdown(
                choices=investor_list, label="Select Investors", value=[], multiselect=True
            )

        # Row 2: valuation bucket plus exclusion filters (these have no "All" entry).
        with gr.Row():
            valuation_range_filter = gr.Dropdown(
                choices=valuation_ranges,
                label="Valuation Range (Billions)",
                value="All"
            )
            exclude_country_filter = gr.Dropdown(
                choices=countries, label="Exclude Country", value=[], multiselect=True
            )
            exclude_industry_filter = gr.Dropdown(
                choices=industries, label="Exclude Industry", value=[], multiselect=True
            )
            exclude_company_filter = gr.Dropdown(
                choices=companies, label="Exclude Company", value=[], multiselect=True
            )
            exclude_investor_filter = gr.Dropdown(
                choices=investor_list, label="Exclude Investors", value=[], multiselect=True
            )

        graph_output = gr.Plot(label="Network Graph")

        # Order must match app()'s positional parameters.
        filter_controls = [
            country_filter,
            industry_filter,
            company_filter,
            investor_filter,
            exclude_country_filter,
            exclude_industry_filter,
            exclude_company_filter,
            exclude_investor_filter,
            valuation_range_filter,
        ]
        plot_outputs = [graph_output]
        for control in filter_controls:
            control.change(app, filter_controls, plot_outputs)

        gr.Markdown("**Instructions:** Use the dropdowns to filter the network graph. For valuation ranges 15–20 or 20+, you’ll see each company's info label without hovering.")
        gr.Markdown("**Note:** All companies are in green, while venture firms have different colors. The diameter of the company circle varies proportionate to the valuation.")

    demo.launch()
if __name__ == "__main__":
    # Build and launch the Gradio UI only when this file is run as a script,
    # not when it is imported.
    main()