networkx-saas / app.py
LeonceNsh's picture
Update app.py
cc1818e verified
raw
history blame
5.15 kB
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import gradio as gr
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load and preprocess the dataset
file_path = "cbinsights_data.csv"  # Replace with your actual file path

try:
    # skiprows=1 drops the first physical line of the file, so the second
    # line is the one parsed as the header row.
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    logger.error(f"File not found: {file_path}")
    raise
except Exception as e:
    logger.error(f"Error loading CSV file: {e}")
    raise
else:
    logger.info("CSV file loaded successfully.")

# Standardize column names: trim, lowercase, and collapse inner whitespace to
# underscores so that a header such as "Select Investors" becomes
# "select_investors" and matches the rename keys below. Headers that already
# use underscores are unaffected.
data.columns = (
    data.columns.str.strip().str.lower().str.replace(r"\s+", "_", regex=True)
)
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Clean and prepare data: strip surrounding whitespace from text cells only.
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
data.rename(columns={
    "company": "Company",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors",
}, inplace=True)

# Surface a schema mismatch here, where it is easy to diagnose, instead of as
# a bare KeyError later on. This only logs; it does not change control flow.
_missing_columns = [
    column
    for column in ("Company", "Country", "Industry", "Select_Investors")
    if column not in data.columns
]
if _missing_columns:
    logger.warning(f"Expected columns missing after renaming: {_missing_columns}")
logger.info("Data cleaned and columns renamed.")
# Build investor-company mapping
def build_investor_company_mapping(df):
    """Group company names under each investor that backs them.

    Args:
        df: DataFrame with a "Company" column and a "Select_Investors"
            column holding comma-separated investor names (or a null).

    Returns:
        A dict mapping each investor name to the list of companies whose
        "Select_Investors" cell lists it, in row order. Keys appear in the
        order the investors are first encountered.
    """
    companies_by_investor = {}
    for _, record in df.iterrows():
        company_name = record["Company"]
        raw_investors = record["Select_Investors"]
        # Rows without any investor information contribute nothing.
        if pd.isnull(raw_investors):
            continue
        for piece in raw_investors.split(","):
            name = piece.strip()
            # Skip empty fragments left by stray or doubled commas.
            if not name:
                continue
            if name not in companies_by_investor:
                companies_by_investor[name] = []
            companies_by_investor[name].append(company_name)
    return companies_by_investor
# Build the investor -> companies mapping once, from the full dataset, at
# import time.
# NOTE(review): no other code visible in this file reads
# investor_company_mapping — confirm whether it is still needed.
investor_company_mapping = build_investor_company_mapping(data)
logger.info("Investor to company mapping created.")
# Filter investors by country and industry (removed valuation threshold)
def filter_investors(selected_country, selected_industry):
    """Narrow the dataset by country and industry and list its investors.

    Args:
        selected_country: Value to match in the "Country" column, or "All"
            to skip the country filter.
        selected_industry: Value to match in the "Industry" column, or "All"
            to skip the industry filter.

    Returns:
        A ``(investors, rows)`` tuple: the investor names found in the
        remaining rows, and the remaining rows themselves as a DataFrame.
    """
    subset = data.copy()
    selections = (("Country", selected_country), ("Industry", selected_industry))
    for column, wanted in selections:
        # "All" is the dropdown sentinel meaning "do not filter on this column".
        if wanted == "All":
            continue
        subset = subset[subset[column] == wanted]
    investor_names = list(build_investor_company_mapping(subset))
    return investor_names, subset
# Generate Plotly graph with increased size and improved styling
def generate_graph(investors, filtered_data):
    """Draw the investor-company network as a Plotly figure.

    Every investor in ``investors`` is linked to each company in
    ``filtered_data`` whose "Select_Investors" cell lists that investor.
    Investors are matched as whole comma-separated names, the same way
    ``build_investor_company_mapping`` splits them. A substring or regex
    match would link "Sequoia Capital" to rows that only list
    "Sequoia Capital China", and would misread names containing regex
    metacharacters.

    Args:
        investors: List of investor names to include in the graph.
        filtered_data: DataFrame with "Company" and "Select_Investors"
            columns.

    Returns:
        A ``go.Figure`` with one line trace for the edges and one
        marker+text trace for the nodes. An empty figure is returned when
        ``investors`` is empty.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # One pass over the rows; avoids re-scanning the frame for every investor.
    companies_by_investor = build_investor_company_mapping(filtered_data)

    G = nx.Graph()
    for investor in investors:
        for company in companies_by_investor.get(investor, []):
            G.add_edge(investor, company)

    # A fixed seed keeps the layout the same for the same graph.
    pos = nx.spring_layout(G, seed=42)

    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None ends the segment so that consecutive edges are
        # not joined into one continuous line.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

    node_names = list(G.nodes())
    node_x = [pos[name][0] for name in node_names]
    node_y = [pos[name][1] for name in node_names]
    node_color = [10] * len(node_names)  # Use a fixed color or other logic

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_names,
        mode='markers+text',
        hoverinfo='text',
        marker=dict(
            showscale=False,
            size=15,  # Increased size
            color=node_color,
        ),
        textposition="top center",  # Improved label positioning
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        showlegend=False,
        title="Venture Networks",
        # layout.title.font.size; replaces the deprecated `titlefont_size`.
        title_font_size=20,
        margin=dict(l=20, r=20, t=50, b=20),
        hovermode='closest',
        width=1200,  # Increased width
        height=800,  # Increased height
    )
    return fig
# Gradio callback (no valuation threshold)
def app(selected_country, selected_industry):
    """Filter the dataset by the two selections and draw the resulting graph.

    Args:
        selected_country: Country selection, passed to ``filter_investors``.
        selected_industry: Industry selection, passed to ``filter_investors``.

    Returns:
        A ``(investors, figure)`` tuple: the investor names returned by
        ``filter_investors`` and the figure returned by ``generate_graph``.
    """
    matching_investors, subset = filter_investors(selected_country, selected_industry)
    figure = generate_graph(matching_investors, subset)
    return matching_investors, figure
# Main function
def main():
    """Assemble the Gradio interface and launch it.

    Builds two dropdowns (country, industry) whose choices are "All" plus the
    sorted distinct non-null values of the matching column, and re-runs
    ``app`` whenever either dropdown changes.
    """
    def dropdown_choices(column):
        # "All" comes first so the default selection applies no filter.
        return ["All"] + sorted(data[column].dropna().unique())

    country_list = dropdown_choices("Country")
    industry_list = dropdown_choices("Industry")

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
        # NOTE(review): the original indentation was lost; the outputs are
        # placed below the filter row here — confirm the intended layout.
        investor_output = gr.Textbox(label="Filtered Investors")
        graph_output = gr.Plot(label="Network Graph")

        filters = [country_filter, industry_filter]
        outputs = [investor_output, graph_output]
        # Both dropdowns trigger the same callback with the same inputs/outputs.
        for dropdown in filters:
            dropdown.change(app, filters, outputs)
    demo.launch()
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()