# networkx-saas / app.py — Hugging Face Space by LeonceNsh
# Last commit: cf8a69c "Update app.py" (verified); file size 6.5 kB
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import gradio as gr
import logging
# Set up logging: configure the root logger at INFO and obtain a logger
# named after this module for the messages emitted below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Load the raw dataset.
# skiprows=1 drops the file's first physical line before the header is parsed
# (presumably a title/banner row in the export — confirm against the file).
file_path = "cbinsights_data.csv"  # Replace with your actual file path
try:
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    # Missing file: record the path, then let the original exception propagate.
    logger.error(f"File not found: {file_path}")
    raise
except Exception as err:
    # Any other read/parse failure is logged and re-raised unchanged.
    logger.error(f"Error loading CSV file: {err}")
    raise
else:
    logger.info("CSV file loaded successfully.")
# Standardize column names: trim surrounding whitespace and lower-case every
# header so later lookups use one canonical spelling.
data.columns = (
    data.columns
    .str.strip()
    .str.lower()
)
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Locate the valuation column: exactly one header may contain "valuation".
# (The .lower() is redundant after the normalization above, but harmless.)
valuation_columns = list(filter(lambda name: 'valuation' in name.lower(), data.columns))
if len(valuation_columns) != 1:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
(valuation_column,) = valuation_columns
logger.info(f"Using valuation column: {valuation_column}")
# Clean and prepare data.
# Parse the valuation text into numbers: remove "$" and thousands separators,
# then coerce; entries that still fail to parse become NaN instead of raising.
# The "$" pattern is a raw string: a plain '\$' literal is an invalid escape
# sequence (DeprecationWarning, and a SyntaxWarning since Python 3.12).
# NOTE(review): the "_billions" name assumes source values are in billions — confirm.
data["valuation_billions"] = data[valuation_column].replace({r'\$': '', ',': ''}, regex=True)
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
# Strip leading/trailing whitespace from every object-dtype (text) column.
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
# Map the normalized lower-case headers onto the names used by the rest of the app.
data.rename(columns={
    "company": "Company",
    "valuation_billions": "Valuation_Billions",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors",
}, inplace=True)
logger.info("Data cleaned and columns renamed.")
# Build investor-company mapping
def build_investor_company_mapping(df):
    """Map each investor name to the companies it appears against in *df*.

    Reads the "Company" and "Select_Investors" columns. The investor cell is
    split on commas and each piece is stripped; blank pieces are ignored and
    rows with a null investor cell contribute nothing. Companies are appended
    in row order, so a company is listed again if the data repeats it.

    Returns:
        dict: investor name -> list of companies, keys in first-seen order.
    """
    investor_to_companies = {}
    for company, raw_investors in zip(df["Company"], df["Select_Investors"]):
        if pd.isnull(raw_investors):
            continue
        for name in (piece.strip() for piece in raw_investors.split(",")):
            if not name:
                continue
            investor_to_companies.setdefault(name, []).append(company)
    return investor_to_companies
# Investor -> companies index over the full, unfiltered dataset.
# NOTE(review): nothing visible in this file reads this name afterwards;
# filter_investors rebuilds its own mapping from the filtered rows.
investor_company_mapping = build_investor_company_mapping(df=data)
logger.info("Investor to company mapping created.")
# Filter investors by country, industry, and valuation threshold
def filter_investors(selected_country, selected_industry, valuation_threshold):
    """Return investors whose filtered portfolio value meets *valuation_threshold*.

    Rows of the global ``data`` are first restricted to *selected_country* and
    *selected_industry*; "All" (or None, e.g. a cleared dropdown) disables that
    filter. For each investor named in the remaining rows, the
    "Valuation_Billions" of every row whose company it backs is summed, and
    investors with a total >= *valuation_threshold* are kept.

    Args:
        selected_country: Country value to keep, or "All"/None for no filter.
        selected_industry: Industry value to keep, or "All"/None for no filter.
        valuation_threshold: Minimum summed valuation for an investor to qualify.

    Returns:
        tuple: (list of qualifying investor names in first-seen order,
        the filtered DataFrame — always a copy of ``data``).
    """
    filtered_data = data.copy()
    if selected_country not in (None, "All"):
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry not in (None, "All"):
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    # Sum valuations per company once; the previous version re-scanned the whole
    # filtered frame with isin() for every investor (O(investors x rows)).
    # groupby's sum skips NaN, matching the original Series.sum(); rows with a
    # missing company name are dropped by groupby.
    company_valuations = filtered_data.groupby("Company")["Valuation_Billions"].sum()
    investor_valuations = {
        # dict.fromkeys de-duplicates while keeping a deterministic summation order.
        investor: sum(company_valuations.get(company, 0.0) for company in dict.fromkeys(companies))
        for investor, companies in investor_company_mapping_filtered.items()
    }
    filtered_investors = [
        investor for investor, total in investor_valuations.items() if total >= valuation_threshold
    ]
    return filtered_investors, filtered_data
# Generate Plotly graph
def generate_graph(investors, filtered_data):
    """Draw the investor–company network for *investors* as a Plotly figure.

    Args:
        investors: Investor names to include (as returned by ``filter_investors``).
        filtered_data: DataFrame with "Company", "Select_Investors" and
            "Valuation_Billions" columns; edges come only from these rows.

    Returns:
        plotly.graph_objects.Figure: an empty figure when *investors* is empty;
        otherwise one line trace for the edges and one marker trace for the
        nodes. Investor markers use the fixed color value 20; company markers
        use the company's summed "Valuation_Billions" (0 if absent).
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # Match investors to companies by exact name, using the same comma-split
    # parsing as filter_investors. Series.str.contains(investor) was wrong on
    # two counts: it interpreted the name as a regex (names containing "(",
    # "+", "." etc. mis-matched or raised) and it matched substrings, e.g.
    # linking "Accel" to companies backed only by "Accel-KKR".
    investor_companies = build_investor_company_mapping(filtered_data)
    G = nx.Graph()
    for investor in investors:
        for company in investor_companies.get(investor, []):
            G.add_edge(investor, company)

    pos = nx.spring_layout(G, seed=42)  # fixed seed -> same layout for the same graph

    # Edge segments; the None after each pair breaks the polyline between edges.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

    investor_set = set(investors)  # O(1) membership instead of scanning the list per node
    # Per-company total valuation, computed once instead of re-filtering per node.
    company_valuation = filtered_data.groupby("Company")["Valuation_Billions"].sum()
    nodes = list(G.nodes())
    node_color = [
        20 if node in investor_set else company_valuation.get(node, 0.0)  # 20 = fixed investor color
        for node in nodes
    ]
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        text=nodes,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=10,
            color=node_color,
            colorbar=dict(
                thickness=15,
                # title.side replaces the deprecated `titleside`, removed in newer Plotly.
                title=dict(text="Valuation (B)", side='right'),
                xanchor='left',
            ),
        ),
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        showlegend=False,
        # title.font replaces the deprecated `titlefont_size`, removed in newer Plotly.
        title=dict(text="Venture Networks", font=dict(size=16)),
        margin=dict(l=40, r=40, t=40, b=40),
        hovermode='closest',
    )
    return fig
# Gradio app
def app(selected_country, selected_industry, valuation_threshold):
    """Gradio callback: filter investors, then render their network.

    Returns:
        tuple: (investor names from ``filter_investors``, Plotly figure from
        ``generate_graph``) for the textbox and plot outputs respectively.
    """
    selection = filter_investors(selected_country, selected_industry, valuation_threshold)
    chosen_investors = selection[0]
    return chosen_investors, generate_graph(*selection)
def main():
    """Build the Gradio UI and start the server.

    Dropdown choices are "All" followed by the sorted distinct non-null
    "Country" / "Industry" values of the loaded data. Changing any of the three
    controls re-runs ``app`` and refreshes both outputs; the outputs are also
    populated once when the page loads.
    """
    country_list = ["All"] + sorted(data["Country"].dropna().unique())
    industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
            valuation_slider = gr.Slider(0, 50, value=20, step=1, label="Valuation Threshold (B)")
        investor_output = gr.Textbox(label="Filtered Investors")
        graph_output = gr.Plot(label="Network Graph")

        inputs = [country_filter, industry_filter, valuation_slider]
        outputs = [investor_output, graph_output]
        # Every control triggers the same callback with the same inputs/outputs.
        for control in inputs:
            control.change(app, inputs, outputs)
        # Render the default selection on page load; previously both outputs
        # stayed empty until the user changed a control.
        demo.load(app, inputs, outputs)
    demo.launch()


if __name__ == "__main__":
    main()