# networkx-saas / app.py
# Last commit: "Update app.py" by LeonceNsh (eedc3a8, verified) — 6.65 kB
# (Hugging Face page header residue; "raw", "history" and "blame" were page links.)
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import gradio as gr
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the raw export. ``skiprows=1`` drops the first physical line of the file
# before the header row is parsed (presumably a title/banner line in the
# export — confirm against the actual file). Any failure is logged and
# re-raised so the app never starts without data.
file_path = "cbinsights_data.csv"  # Replace with your actual file path
try:
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    logger.error(f"File not found: {file_path}")
    raise
except Exception as err:
    logger.error(f"Error loading CSV file: {err}")
    raise
else:
    logger.info("CSV file loaded successfully.")
# Standardize column names: trim surrounding whitespace and lower-case them so
# the rename map below can match regardless of how the export capitalises headers.
data.columns = data.columns.str.strip().str.lower()
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Clean and prepare data: strip whitespace from every text cell. Only ``str``
# values are touched — ``col.str.strip()`` would silently turn any non-string
# value in an object column (e.g. a number) into NaN.
data = data.apply(
    lambda col: col.map(lambda v: v.strip() if isinstance(v, str) else v)
    if col.dtype == "object"
    else col
)
data.rename(columns={
    "company": "Company",
    "valuation": "Valuation",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors",
}, inplace=True)

# Columns the rest of the app reads. Fail here with a readable message rather
# than with a bare KeyError deep inside the graph code.
_required_columns = ["Company", "Valuation", "Country", "Industry", "Select_Investors"]
_missing_columns = [c for c in _required_columns if c not in data.columns]
if _missing_columns:
    logger.error(
        f"Missing expected columns: {_missing_columns}; found {data.columns.tolist()}"
    )
    raise KeyError(f"Missing expected columns: {_missing_columns}")

# Convert valuation to numeric for proportional node sizing: drop "$" and ","
# first; anything still unparsable becomes NaN (errors="coerce").
# Raw string: "\$" in a normal literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+).
data["Valuation"] = pd.to_numeric(
    data["Valuation"].replace({r"\$": "", ",": ""}, regex=True), errors="coerce"
)
logger.info("Data cleaned and columns renamed.")
# Build investor-company mapping
def build_investor_company_mapping(df):
    """Index companies by investor.

    Each ``Select_Investors`` cell is treated as a comma-separated list of
    names. Names are whitespace-stripped and empty entries are ignored. Rows
    whose investor cell is null contribute nothing.

    Args:
        df: DataFrame with ``Company`` and ``Select_Investors`` columns.

    Returns:
        A plain dict mapping each investor name to the list of companies it
        appears on, in row order. Keys are in first-seen order.
    """
    if len(df) == 0:
        return {}

    investors_to_companies = {}
    for company, investor_field in zip(df["Company"], df["Select_Investors"]):
        if pd.isnull(investor_field):
            continue
        names = (part.strip() for part in investor_field.split(","))
        for name in filter(None, names):
            investors_to_companies.setdefault(name, []).append(company)
    return investors_to_companies
# Investor -> companies index over the full, unfiltered dataset; its keys
# populate the "Investors" choices in main().
investor_company_mapping = build_investor_company_mapping(df=data)
logger.info("Investor to company mapping created.")
# Filter investors by country, industry, and investor selection
def filter_investors(selected_country, selected_industry, selected_investors):
    """Filter the module-level ``data`` and list the investors that remain.

    Args:
        selected_country: Country value to keep, or "All" for no country filter.
        selected_industry: Industry value to keep, or "All" for no industry filter.
        selected_investors: Investor names chosen in the UI. If it contains
            "All", no investor filter is applied. Otherwise only companies with
            at least one exactly-matching investor are kept. ``None`` or an
            empty selection keeps no companies.

    Returns:
        ``(filtered_investors, filtered_data)``:
        - ``filtered_investors``: every investor appearing on the remaining
          rows, including co-investors of the selected ones, in first-seen order.
        - ``filtered_data``: the filtered DataFrame (a copy of ``data``).
    """
    filtered_data = data.copy()
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    # "All" anywhere in the selection disables the investor filter. The old
    # ``!= ["All"]`` test only did so when "All" was the *only* choice, so
    # "All" + a name searched for the literal text "All".
    selected = set(selected_investors or [])
    if "All" not in selected:
        def _backed_by_selected(raw):
            # Exact match on the same comma-split/strip tokens used by
            # build_investor_company_mapping. The previous substring test let
            # "Accel" match "Accel-KKR".
            if not isinstance(raw, str):
                return False
            return any(part.strip() in selected for part in raw.split(","))

        # astype(bool) keeps the mask boolean even when it is empty. An empty
        # object-dtype Series is not recognised as a boolean indexer and would
        # select columns instead of rows.
        mask = filtered_data["Select_Investors"].map(_backed_by_selected).astype(bool)
        filtered_data = filtered_data.loc[mask]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    filtered_investors = list(investor_company_mapping_filtered.keys())
    return filtered_investors, filtered_data
# Generate Plotly graph
def generate_graph(investors, filtered_data):
    """Build a Plotly network figure linking investors to their companies.

    Args:
        investors: Investor names to draw as coloured hub nodes (e.g. the
            first value returned by ``filter_investors``).
        filtered_data: DataFrame with at least ``Company``,
            ``Select_Investors`` and ``Valuation`` columns.

    Returns:
        A ``go.Figure``. It is empty (with a logged warning) when
        ``investors`` is empty.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    investor_set = set(investors)

    # Index companies by investor in one pass. This uses the same
    # comma-split/strip rule that produced the investor names. The old
    # ``str.contains(investor)`` treated each name as a regular expression
    # (names with "(", "+", "." could raise or mis-match) and matched substrings.
    companies_by_investor = {}
    for company, raw in zip(filtered_data["Company"], filtered_data["Select_Investors"]):
        if not isinstance(raw, str):
            continue
        for name in raw.split(","):
            name = name.strip()
            if name in investor_set:
                companies_by_investor.setdefault(name, []).append(company)

    G = nx.Graph()
    # Edges are added in the caller's investor order. Node order feeds the
    # seeded spring layout, so this keeps the picture reproducible.
    for investor in investors:
        for company in companies_by_investor.get(investor, []):
            G.add_edge(investor, company)

    pos = nx.spring_layout(G, seed=42)

    edge_x = []
    edge_y = []
    for u, v in G.edges():
        (x0, y0), (x1, y1) = pos[u], pos[v]
        # None breaks the line so each edge is drawn as a separate segment.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    # Colour-blind-friendly palette (the Okabe-Ito colours), cycled across investors.
    investor_colors = [
        "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"
    ]
    investor_color_map = {
        investor: investor_colors[i % len(investor_colors)]
        for i, investor in enumerate(investors)
    }
    # Total valuation per company, computed once instead of a full-frame mask
    # per node. NaN valuations sum to 0, as a per-company .sum() did.
    valuation_by_company = filtered_data.groupby("Company")["Valuation"].sum()

    node_x = []
    node_y = []
    node_text = []
    node_color = []
    node_size = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        if node in investor_set:
            node_text.append(node)
            node_color.append(investor_color_map[node])
            node_size.append(20)  # Fixed size for investor nodes
        else:
            valuation = valuation_by_company.get(node, 0)
            # mode="markers" never draws text labels, so ``text`` only feeds
            # the hover tooltip. Show the company name there instead of "".
            node_text.append(str(node))
            node_color.append("lightgreen")
            node_size.append(max(10, valuation / 100))  # Size grows with valuation

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        mode="markers",
        hoverinfo="text",
        marker=dict(
            showscale=False,
            size=node_size,
            color=node_color,
        ),
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        showlegend=False,
        # ``titlefont_size`` is deprecated and was removed in plotly 6. The
        # nested title.font form works on plotly 4+ as well.
        title=dict(text="Venture Networks", font=dict(size=20)),
        margin=dict(l=20, r=20, t=50, b=20),
        hovermode="closest",
        width=1200,
        height=800,
    )
    return fig
# Gradio app
def app(selected_country, selected_industry, selected_investors):
    """Gradio callback: apply the three filters and redraw the network.

    Args:
        selected_country: Value of the Country dropdown.
        selected_industry: Value of the Industry dropdown.
        selected_investors: Value of the Investors checkbox group.

    Returns:
        ``(investor_text, figure)``: the filtered investor names joined with
        ", " for the Textbox output, and the Plotly figure for the Plot output.
    """
    investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
    graph = generate_graph(investors, filtered_data)
    # gr.Textbox displays a string. A raw list would be shown as its Python
    # repr, with brackets and quotes.
    return ", ".join(investors), graph
# Main function
def main():
    """Build the Gradio dashboard and start the server.

    Country, industry and investor filters all drive ``app``. A change to any
    one of them recomputes the filtered investor list and the network graph.
    """
    def _with_all(values):
        # "All" is always offered first, ahead of the sorted real values.
        return ["All", *sorted(values)]

    country_choices = _with_all(data["Country"].dropna().unique())
    industry_choices = _with_all(data["Industry"].dropna().unique())
    investor_choices = _with_all(investor_company_mapping)

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_choices, label="Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_choices, label="Industry", value="All")
            investor_filter = gr.CheckboxGroup(choices=investor_choices, label="Investors", value=["All"])
        investor_output = gr.Textbox(label="Filtered Investors")
        graph_output = gr.Plot(label="Network Graph")

        filter_inputs = [country_filter, industry_filter, investor_filter]
        outputs = [investor_output, graph_output]
        # Same callback, inputs and outputs for every filter.
        for component in filter_inputs:
            component.change(app, filter_inputs, outputs)

    demo.launch()


if __name__ == "__main__":
    main()