# Spaces: Sleeping
import pandas as pd | |
import networkx as nx | |
import plotly.graph_objects as go | |
import gradio as gr | |
import logging | |
# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read the dataset into `data`; skiprows=1 drops the first line of the file
# before the header row is parsed.
file_path = "cbinsights_data.csv"  # Replace with your actual file path
try:
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    # A missing file gets its own message; the error is re-raised unchanged.
    logger.error(f"File not found: {file_path}")
    raise
except Exception as load_error:
    # Any other failure while reading is logged and re-raised.
    logger.error(f"Error loading CSV file: {load_error}")
    raise
else:
    logger.info("CSV file loaded successfully.")
# Standardize column names: trim, lowercase, and collapse runs of whitespace
# into underscores. Without the last step a header such as "Select Investors"
# becomes "select investors" and never matches the snake_case keys of the
# rename mapping below, leaving the "Select_Investors" column missing.
data.columns = (
    data.columns.str.strip()
    .str.lower()
    .str.replace(r"\s+", "_", regex=True)
)
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Identify the valuation column: exactly one column name must contain
# the substring "valuation".
valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
if len(valuation_columns) != 1:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
valuation_column = valuation_columns[0]
logger.info(f"Using valuation column: {valuation_column}")

# Clean and prepare data.
# Strip "$" and thousands separators, then convert to numbers. The pattern is
# a raw string so that "\$" is a regex escape rather than an invalid Python
# string escape. Values that still fail to parse become NaN (errors='coerce').
data["valuation_billions"] = data[valuation_column].replace(
    {r"\$": "", ",": ""}, regex=True
)
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')

# Trim surrounding whitespace in every object-dtype (text) column.
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)

# Give the columns used by the rest of the script their canonical names.
data.rename(
    columns={
        "company": "Company",
        "valuation_billions": "Valuation_Billions",
        "date_joined": "Date_Joined",
        "country": "Country",
        "city": "City",
        "industry": "Industry",
        "select_investors": "Select_Investors",
    },
    inplace=True,
)
logger.info("Data cleaned and columns renamed.")
# Build investor-company mapping
def build_investor_company_mapping(df):
    """Map each investor name to the list of companies it is listed against.

    Args:
        df: DataFrame with a "Company" column and a "Select_Investors" column
            holding comma-separated investor names (or a null value).

    Returns:
        A dict of investor name -> list of company values, in row order.
        Rows with a null investor cell and empty name fragments are skipped.
    """
    mapping = {}
    for _, record in df.iterrows():
        company_name = record["Company"]
        raw_investors = record["Select_Investors"]
        if pd.isnull(raw_investors):
            continue
        # Trim each comma-separated fragment and drop the empty ones.
        names = [fragment.strip() for fragment in raw_investors.split(",")]
        for name in names:
            if not name:
                continue
            mapping.setdefault(name, []).append(company_name)
    return mapping
# Investor -> companies lookup over the full, unfiltered dataset.
investor_company_mapping = build_investor_company_mapping(df=data)
logger.info("Investor to company mapping created.")
# Filter investors by country, industry, and valuation threshold
def filter_investors(selected_country, selected_industry, valuation_threshold):
    """Select investors whose summed portfolio valuation meets a threshold.

    Args:
        selected_country: Country to keep, or "All" to skip the country filter.
        selected_industry: Industry to keep, or "All" to skip the industry filter.
        valuation_threshold: Minimum total of "Valuation_Billions" (inclusive)
            an investor needs across its companies in the filtered rows.

    Returns:
        A tuple ``(investors, filtered_data)``: the qualifying investor names
        and the filtered copy of the module-level ``data`` frame.
    """
    subset = data.copy()
    if selected_country != "All":
        subset = subset[subset["Country"] == selected_country]
    if selected_industry != "All":
        subset = subset[subset["Industry"] == selected_industry]

    portfolios = build_investor_company_mapping(subset)

    qualifying = []
    for investor_name, portfolio in portfolios.items():
        # Sum the valuations of every filtered row whose company belongs to
        # this investor's portfolio.
        in_portfolio = subset["Company"].isin(portfolio)
        total_valuation = subset[in_portfolio]["Valuation_Billions"].sum()
        if total_valuation >= valuation_threshold:
            qualifying.append(investor_name)
    return qualifying, subset
# Generate Plotly graph
def generate_graph(investors, filtered_data):
    """Build a Plotly network figure connecting investors to their companies.

    Args:
        investors: Investor names to draw.
        filtered_data: DataFrame providing the "Company", "Select_Investors"
            and "Valuation_Billions" columns.

    Returns:
        A ``go.Figure`` with one line trace for the edges and one marker trace
        for the nodes. An empty figure is returned when ``investors`` is empty.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # Resolve companies by exact investor name using the same comma-splitting
    # helper that produced the investor list. A substring/regex search
    # (str.contains) would link "Accel" to rows listing "Accel Partners" and
    # would misbehave on names containing regex metacharacters such as "(".
    companies_by_investor = build_investor_company_mapping(filtered_data)
    investor_names = set(investors)  # O(1) membership tests when colouring

    G = nx.Graph()
    for investor in investors:
        for company in companies_by_investor.get(investor, []):
            G.add_edge(investor, company)

    # Fixed seed keeps the layout reproducible between calls.
    pos = nx.spring_layout(G, seed=42)

    # Each edge contributes (start, end, None); the None breaks the line so
    # all edges can be drawn by a single Scatter trace.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

    node_x = []
    node_y = []
    node_text = []
    node_color = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
        if node in investor_names:
            node_color.append(20)  # Fixed color value for investors
        else:
            # Company nodes are coloured by their total valuation.
            valuation = filtered_data.loc[
                filtered_data["Company"] == node, "Valuation_Billions"
            ].sum()
            node_color.append(valuation)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=10,
            color=node_color,
            colorbar=dict(
                thickness=15,
                # Nested title form replaces the deprecated `titleside`.
                title=dict(text="Valuation (B)", side='right'),
                xanchor='left',
            ),
        ),
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        showlegend=False,
        # Nested title form replaces the deprecated `titlefont_size`.
        title=dict(text="Venture Networks", font=dict(size=16)),
        margin=dict(l=40, r=40, t=40, b=40),
        hovermode='closest',
    )
    return fig
# Gradio app
def app(selected_country, selected_industry, valuation_threshold):
    """Run the investor filter and render the matching network graph.

    Args:
        selected_country: Country filter value forwarded to ``filter_investors``.
        selected_industry: Industry filter value forwarded to ``filter_investors``.
        valuation_threshold: Valuation threshold forwarded to ``filter_investors``.

    Returns:
        A tuple ``(investors, figure)`` of the filtered investor names and the
        figure produced by ``generate_graph``.
    """
    selection = filter_investors(selected_country, selected_industry, valuation_threshold)
    investors, filtered_data = selection
    figure = generate_graph(investors, filtered_data)
    return investors, figure
def main():
    """Assemble the Gradio interface and launch it.

    The UI has country and industry dropdowns plus a valuation slider; a change
    to any of the three calls ``app`` with all three current values and writes
    the results to the investor textbox and the plot.
    """

    def _choices(column):
        # "All" first, followed by the sorted distinct non-null column values.
        return ["All"] + sorted(data[column].dropna().unique())

    country_list = _choices("Country")
    industry_list = _choices("Industry")

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
            valuation_slider = gr.Slider(0, 50, value=20, step=1, label="Valuation Threshold (B)")
        investor_output = gr.Textbox(label="Filtered Investors")
        graph_output = gr.Plot(label="Network Graph")

        # Register the same callback on every control, in declaration order.
        for control in (country_filter, industry_filter, valuation_slider):
            control.change(
                app,
                [country_filter, industry_filter, valuation_slider],
                [investor_output, graph_output],
            )
    demo.launch()


if __name__ == "__main__":
    main()