Spaces:
Running
Running
File size: 3,202 Bytes
5c05d7a 2f36052 5c05d7a 2f36052 5c05d7a 2f36052 5c05d7a e6abf6e 5c05d7a e6abf6e 5c05d7a 9e2bc99 5c05d7a 2f36052 04d2663 05d82ce dd5783d 5c05d7a 04d2663 5240604 04d2663 5240604 e830460 5240604 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import gradio as gr
import re
import logging
# Logging: create this module's logger, then configure the root logger at INFO
# so the progress messages emitted below are shown.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Read the dataset into a DataFrame. Any failure is logged and then re-raised,
# so the script stops instead of continuing without data.
file_path = "cbinsights_data.csv"  # Replace with your actual file path
try:
    # skiprows=1 discards the first line of the file; the header is read from
    # the line that follows it.
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    logger.error(f"File not found: {file_path}")
    raise
except Exception as load_error:
    logger.error(f"Error loading CSV file: {load_error}")
    raise
else:
    logger.info("CSV file loaded successfully.")
# Standardize column names: trim surrounding whitespace and lower-case them so
# the lookups below ("industry", "company", ...) do not depend on how the CSV
# header happens to be capitalised.
data.columns = data.columns.str.strip().str.lower()
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Filter out 'Health' since 'Healthcare' is the correct Market Segment.
# .copy() detaches the filtered rows from the original frame: a new column is
# assigned to `data` further down, and doing that on a filtered slice triggers
# pandas' SettingWithCopyWarning. Bracket access is used because attribute
# access can be shadowed by DataFrame attributes/methods of the same name.
data = data[data["industry"] != "Health"].copy()
# Locate the valuation column: exactly one column name must contain the
# substring "valuation"; anything else is treated as a malformed dataset.
valuation_columns = list(
    filter(lambda name: "valuation" in name.lower(), data.columns)
)
if len(valuation_columns) == 1:
    valuation_column = valuation_columns[0]
    logger.info(f"Using valuation column: {valuation_column}")
else:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
# Clean and prepare data
# Matches every character that is not a digit or a decimal point.
_NON_NUMERIC = re.compile(r"[^0-9.]")


def _valuation_to_billions(raw):
    """Convert one raw valuation cell to a float divided by 1e9.

    Every character other than digits and "." is stripped, the remainder is
    parsed as a float and divided by 1e9. Null cells yield 0.0. Cells that
    contain no parsable number after stripping (e.g. "N/A", "-", "1.2.3")
    also yield 0.0 instead of raising ValueError and aborting the script.
    """
    if pd.isnull(raw):
        return 0.0
    cleaned = _NON_NUMERIC.sub("", str(raw))
    try:
        return float(cleaned) / 1e9
    except ValueError:
        # float("") or a malformed number: treat like a missing valuation.
        return 0.0


# NOTE(review): dividing by 1e9 assumes the raw column holds plain dollar
# amounts. If the export already reports billions (e.g. "$1.5"), the result
# will be far too small -- confirm against the actual data.
data["valuation_billions"] = data[valuation_column].apply(_valuation_to_billions)
logger.info("Valuation column cleaned and converted to billions.")
# Create a graph: one node per company row, plus an edge for every entry in the
# optional semicolon-separated "relationships" column.
G = nx.Graph()

# Whether the column exists does not change per row, so check it once.
has_relationships = "relationships" in data.columns

for _, row in data.iterrows():
    company_name = row["company"]
    valuation = row["valuation_billions"]
    industry = row["industry"]

    # Add company node. If the node already exists (e.g. it was created earlier
    # as a relationship target), add_node updates its attributes in place.
    G.add_node(
        company_name,
        size=valuation,
        color="green" if industry == "Venture Firm" else "blue",
    )

    # Add connections based on relationships.
    if not has_relationships or pd.isnull(row["relationships"]):
        # A missing cell would otherwise be stringified to "nan" and linked
        # to the company as if it were a real node.
        continue

    for relation in str(row["relationships"]).split(";"):
        target = relation.strip()
        if not target:
            # Skip empty fragments produced by trailing or doubled separators.
            continue
        if target not in G:
            # add_edge would create this node without attributes, and the
            # plotting code reads "size" and "color" from every node. Give it
            # placeholder values; they are overwritten if the target later
            # appears as a company row of its own.
            G.add_node(target, size=0, color="blue")
        G.add_edge(company_name, target)
# Create Plotly visualization
# Nodes that exist only as edge endpoints may carry no "size"/"color"
# attribute; fall back to size 0 and the default colour instead of raising
# KeyError. Marker size is the stored valuation scaled by 50.
node_sizes = [size * 50 for _, size in G.nodes(data="size", default=0)]
node_colors = [color for _, color in G.nodes(data="color", default="blue")]

# A fixed seed makes the force-directed layout reproducible between runs.
pos = nx.spring_layout(G, seed=42)
x_coords = [pos[node][0] for node in G.nodes]
y_coords = [pos[node][1] for node in G.nodes]

# One scatter trace draws every node as a labelled marker. Edges are stored in
# the graph and influence the layout, but no trace is drawn for them here.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=x_coords,
        y=y_coords,
        mode="markers+text",
        marker=dict(size=node_sizes, color=node_colors, opacity=0.8),
        text=list(G.nodes),
        textposition="top center",
    )
)
fig.update_layout(
    title="Company Network Visualization",
    xaxis=dict(showgrid=False, zeroline=False),
    yaxis=dict(showgrid=False, zeroline=False),
    showlegend=False,
)
# Note: Venture firms are drawn in green; all other nodes are drawn in blue.
# The diameter of each circle is proportional to the valuation of the corresponding company.
# Create Gradio interface
def display_network():
    """Return the network figure as an HTML string for Gradio's "html" output.

    The figure is serialised with ``fig.to_html()`` and embedded in an
    ``<iframe srcdoc=...>``. The iframe gives the plot its own document, so the
    scripts plotly emits can run.

    NOTE(review): Gradio's HTML component is documented as rendering static
    HTML only, which is why the raw ``to_html()`` output is not returned
    directly -- confirm against the Gradio version in use.

    Returns:
        str: an ``<iframe>`` element whose ``srcdoc`` holds the escaped page.
    """
    import html  # local import: only this function needs it

    document = fig.to_html()
    # Escape the document so it can sit inside a double-quoted attribute; the
    # browser decodes the entities back into the original markup.
    return (
        '<iframe style="width: 100%; height: 600px; border: none;" '
        f'srcdoc="{html.escape(document, quote=True)}"></iframe>'
    )
# Build the Gradio app around display_network (no inputs, a single HTML
# output) and start serving it.
demo = gr.Interface(
    title="Company Network Visualization",
    outputs="html",
    inputs=[],
    fn=display_network,
)
demo.launch()
|