Spaces:
Running
Running
import pandas as pd | |
import networkx as nx | |
import plotly.graph_objects as go | |
import gradio as gr | |
import re | |
import logging | |
# Configure logging once for the whole script and grab a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read the dataset from disk. The first line of the file is skipped
# (presumably a title row above the real header -- confirm against the CSV).
file_path = "cbinsights_data.csv"  # Replace with your actual file path
try:
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    # A missing file gets its own message so the path is easy to spot.
    logger.error(f"File not found: {file_path}")
    raise
except Exception as e:
    # Any other read/parse failure is logged and then propagated unchanged.
    logger.error(f"Error loading CSV file: {e}")
    raise
else:
    logger.info("CSV file loaded successfully.")
# Standardize column names (trim surrounding whitespace, lower-case)
data.columns = data.columns.str.strip().str.lower()
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Filter out 'Health' since 'Healthcare' is the correct Market Segment.
# .copy() detaches the filtered frame from the original so that adding the
# "valuation_billions" column below does not trigger SettingWithCopyWarning.
data = data[data["industry"] != "Health"].copy()

# Identify the valuation column: exactly one column name must contain
# the substring 'valuation'.
valuation_columns = [col for col in data.columns if "valuation" in col.lower()]
if len(valuation_columns) != 1:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
valuation_column = valuation_columns[0]
logger.info(f"Using valuation column: {valuation_column}")


def _parse_valuation_billions(raw):
    """Convert one raw valuation cell to a float expressed in billions.

    Every character other than digits and '.' is stripped before parsing,
    and the parsed number is divided by 1e9. Missing cells, and cells that
    leave nothing parseable after stripping (e.g. "N/A" or "1.2.3"), yield
    0.0 instead of raising, so one bad row cannot abort the whole script.
    """
    if pd.isnull(raw):
        return 0.0
    numeric_text = re.sub(r"[^0-9.]", "", str(raw))
    try:
        # NOTE(review): the division assumes the raw figure is in plain
        # currency units; if the column is already in billions this will
        # shrink every value to ~0 -- confirm against the CSV.
        return float(numeric_text) / 1e9
    except ValueError:
        logger.warning(f"Could not parse valuation {raw!r}; using 0.")
        return 0.0


# Clean and prepare data
data["valuation_billions"] = data[valuation_column].apply(_parse_valuation_billions)
logger.info("Valuation column cleaned and converted to billions.")
# Create a graph with one node per company row and, when the dataset has a
# "relationships" column, one edge per ';'-separated related name.
G = nx.Graph()

# The column check does not depend on the row, so do it once up front.
has_relationships = "relationships" in data.columns

for _, row in data.iterrows():
    company_name = row["company"]
    valuation = row["valuation_billions"]
    industry = row["industry"]

    # Add (or update) the company node. add_node on an existing node updates
    # its attributes, so a placeholder created earlier gets real values here.
    G.add_node(
        company_name,
        size=valuation,
        color="green" if industry == "Venture Firm" else "blue",
    )

    # Rows without relationship data contribute no edges. Without this guard
    # a missing value would be stringified to "nan" and become a bogus node.
    if not has_relationships or pd.isnull(row["relationships"]):
        continue

    for relation in str(row["relationships"]).split(";"):
        related_name = relation.strip()
        if not related_name:
            # Skip empty fragments from trailing or doubled ';'.
            continue
        if related_name not in G:
            # add_edge would create this node without attributes; give it
            # defaults so later lookups of "size" / "color" cannot KeyError.
            G.add_node(related_name, size=0, color="blue")
        G.add_edge(company_name, related_name)
# Create Plotly visualization
# Nodes that were added implicitly as edge endpoints may carry no attributes,
# so fall back to size 0 / blue rather than raising KeyError.
node_sizes = [G.nodes[node].get("size", 0) * 50 for node in G.nodes]
node_colors = [G.nodes[node].get("color", "blue") for node in G.nodes]

# A fixed seed keeps the force-directed layout identical between runs.
pos = nx.spring_layout(G, seed=42)
x_coords = [pos[node][0] for node in G.nodes]
y_coords = [pos[node][1] for node in G.nodes]

# Build one polyline for all edges; a None entry after each pair of
# endpoints breaks the line so separate edges are not joined together.
edge_x = []
edge_y = []
for source, target in G.edges:
    edge_x.extend([pos[source][0], pos[target][0], None])
    edge_y.extend([pos[source][1], pos[target][1], None])

fig = go.Figure()
# Edges are added first so the node markers are drawn on top of them.
fig.add_trace(
    go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(width=0.5, color="#888"),
        hoverinfo="none",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_coords,
        y=y_coords,
        mode="markers+text",
        marker=dict(size=node_sizes, color=node_colors, opacity=0.8),
        text=list(G.nodes),
        textposition="top center",
    )
)
fig.update_layout(
    title="Company Network Visualization",
    xaxis=dict(showgrid=False, zeroline=False),
    yaxis=dict(showgrid=False, zeroline=False),
    showlegend=False,
)
# Note: Venture firms are drawn in green; all other companies are drawn in blue.
# The diameter of each marker is proportional to the valuation of the corresponding company.
# Create Gradio interface
def display_network():
    """Return the pre-built Plotly figure for Gradio to render.

    The figure object itself is returned (not ``fig.to_html()``): markup
    injected into an HTML output does not execute its <script> tags, so the
    serialized plot would show up blank. A Plot output renders the figure
    natively.
    """
    return fig


gr.Interface(
    fn=display_network,
    inputs=[],
    outputs=gr.Plot(),
    title="Company Network Visualization",
).launch()