File size: 3,202 Bytes
5c05d7a
2f36052
5c05d7a
2f36052
5c05d7a
 
2f36052
5c05d7a
 
 
e6abf6e
5c05d7a
 
e6abf6e
5c05d7a
 
 
 
 
 
 
 
 
9e2bc99
5c05d7a
 
 
2f36052
04d2663
05d82ce
dd5783d
5c05d7a
 
 
 
 
 
 
 
 
 
04d2663
5240604
04d2663
5240604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e830460
5240604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import gradio as gr
import re
import logging

# Logging: a module-scoped logger, with the root logger configured for INFO and above
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Read the dataset from the CSV at file_path. skiprows=1 makes pandas discard
# the first line of the file before it looks for the header row.
file_path = "cbinsights_data.csv"  # Replace with your actual file path

try:
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    # Missing file: log the offending path, then let the error propagate.
    logger.error(f"File not found: {file_path}")
    raise
except Exception as load_error:
    # Any other read/parse failure is logged and re-raised unchanged.
    logger.error(f"Error loading CSV file: {load_error}")
    raise
else:
    logger.info("CSV file loaded successfully.")

# Standardize column names: strip surrounding whitespace and lowercase them so
# the lookups below ("industry", "company", ...) do not depend on the CSV's
# original header casing.
data.columns = data.columns.str.strip().str.lower()
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Filter out 'Health' since 'Healthcare' is the correct Market Segment.
# .copy() gives the filtered rows their own DataFrame; without it pandas treats
# the result as a copy of a slice and emits SettingWithCopyWarning when the
# valuation_billions column is assigned further down.
data = data[data["industry"] != "Health"].copy()

# Find the column holding valuations: exactly one column name must contain
# the substring "valuation" (case-insensitive), otherwise the script aborts.
valuation_columns = list(
    filter(lambda name: "valuation" in name.lower(), data.columns)
)

if len(valuation_columns) == 1:
    (valuation_column,) = valuation_columns
    logger.info(f"Using valuation column: {valuation_column}")
else:
    logger.error("Unable to identify a single valuation column.")
    raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")

# Clean and prepare data
def _valuation_to_billions(raw_value):
    """Convert one raw valuation cell to a float expressed in billions.

    Every character other than digits and '.' is stripped from the cell's
    string form, the remainder is parsed as a float and divided by 1e9.

    Args:
        raw_value: A single cell from the valuation column.

    Returns:
        The parsed value divided by 1e9, or 0.0 when the cell is null or
        contains nothing that parses as a number (e.g. "N/A", "", "1.2.3").
    """
    if pd.isnull(raw_value):
        return 0.0

    numeric_text = re.sub(r"[^0-9.]", "", str(raw_value))
    try:
        # NOTE(review): the / 1e9 assumes the raw figures are absolute
        # amounts (not already in billions) -- confirm against the dataset.
        return float(numeric_text) / 1e9
    except ValueError:
        # float("") or float("1.2.3") would otherwise abort the whole script;
        # treat such cells like missing values instead.
        logger.warning(f"Unparseable valuation {raw_value!r}; using 0.")
        return 0.0


data["valuation_billions"] = data[valuation_column].apply(_valuation_to_billions)
logger.info("Valuation column cleaned and converted to billions.")

# Create a graph: one node per dataset row, plus one edge per entry in the
# optional ';'-separated "relationships" column.
G = nx.Graph()

# Whether the column exists is the same for every row, so check it once.
has_relationships = "relationships" in data.columns

for _, row in data.iterrows():
    company_name = row["company"]
    valuation = row["valuation_billions"]
    industry = row["industry"]

    # Add company node (re-adding an existing node just updates its attributes)
    G.add_node(
        company_name,
        size=valuation,
        color="green" if industry == "Venture Firm" else "blue",
    )

    # A missing cell would otherwise be stringified to "nan" and linked as if
    # it were a real node, so rows without relationships are skipped.
    if not has_relationships or pd.isnull(row["relationships"]):
        continue

    for relation in str(row["relationships"]).split(";"):
        partner = relation.strip()
        if not partner:
            # Empty tokens come from trailing or doubled ';' separators.
            continue
        if partner not in G:
            # add_edge() would create this node without attributes, and the
            # plotting code reads "size" and "color" from every node. Give
            # nodes that have no dataset row of their own neutral defaults.
            G.add_node(partner, size=0.0, color="gray")
        G.add_edge(company_name, partner)

# Create Plotly visualization
# Take one snapshot of the node order so sizes, colors, coordinates and labels
# are all built from the same sequence.
node_names = list(G.nodes)

# Nodes that were created implicitly as edge endpoints have no "size"/"color"
# attributes; fall back to defaults instead of raising KeyError.
node_sizes = [G.nodes[node].get("size", 0) * 50 for node in node_names]
node_colors = [G.nodes[node].get("color", "gray") for node in node_names]

# Force-directed layout: maps each node to an (x, y) position.
pos = nx.spring_layout(G)
x_coords = [pos[node][0] for node in node_names]
y_coords = [pos[node][1] for node in node_names]

fig = go.Figure()

# Single scatter trace: one labelled marker per node (edges are not drawn).
fig.add_trace(
    go.Scatter(
        x=x_coords,
        y=y_coords,
        mode="markers+text",
        marker=dict(size=node_sizes, color=node_colors, opacity=0.8),
        text=node_names,
        textposition="top center",
    )
)

fig.update_layout(
    title="Company Network Visualization",
    xaxis=dict(showgrid=False, zeroline=False),
    yaxis=dict(showgrid=False, zeroline=False),
    showlegend=False,
)

# Note: Venture firms are drawn in green; all other companies are drawn in blue.
# The diameter of each company's circle is proportional to that company's valuation.

# Create Gradio interface
def display_network():
    """Return the module-level network figure serialized as an HTML string."""
    html_markup = fig.to_html()
    return html_markup

def _network_figure():
    """Return the Plotly figure object for Gradio's plot output."""
    return fig


# Gradio's "html" output renders static markup only: the <script> tags that
# Plotly's exported HTML depends on are never executed, so the chart would not
# draw. The "plot" output accepts the Plotly figure object directly.
# display_network() (HTML export) stays defined but is no longer the callback.
gr.Interface(
    fn=_network_figure,
    inputs=[],
    outputs="plot",
    title="Company Network Visualization",
).launch()