File size: 6,557 Bytes
a165958
45a7450
 
 
9349152
a165958
45a7450
a165958
 
cc9514f
a165958
1322835
 
 
 
 
 
 
 
 
 
 
 
 
a165958
1322835
 
a165958
 
1322835
 
 
 
 
 
 
 
 
 
 
 
9aa537c
 
 
 
1322835
9aa537c
 
 
 
 
 
1322835
 
9aa537c
9327810
9aa537c
e63418f
 
 
 
9aa537c
 
 
9327810
 
9aa537c
 
1322835
9327810
9aa537c
9327810
 
9aa537c
9327810
 
 
 
9aa537c
 
e63418f
9aa537c
e63418f
 
45a7450
a165958
45a7450
9aa537c
 
a165958
45a7450
e63418f
1322835
45a7450
9aa537c
1322835
9aa537c
 
 
 
45a7450
9aa537c
 
 
45a7450
 
 
99eb020
45a7450
09b019d
1322835
9aa537c
45a7450
09b019d
9aa537c
 
 
 
 
 
 
09b019d
9aa537c
 
45a7450
9aa537c
45a7450
 
 
 
9349152
9aa537c
45a7450
9aa537c
9327810
 
9aa537c
9327810
 
 
 
 
a165958
45a7450
e63418f
 
45a7450
9327810
 
 
 
 
9aa537c
9327810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45a7450
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import gradio as gr

# Load and preprocess the dataset
file_path = "cbinsights_data.csv"  # Replace with your file path
# skiprows=1 skips the first line of the file, so the header is read from line 2.
data = pd.read_csv(file_path, skiprows=1)

# Standardize column names: strip whitespace and convert to lowercase
data.columns = data.columns.str.strip().str.lower()
print("Standardized Column Names:", data.columns.tolist())

# Identify the valuation column dynamically (names are already lowercase here)
valuation_columns = [col for col in data.columns if 'valuation' in col]
if not valuation_columns:
    raise ValueError("No column containing 'Valuation' found in the dataset.")
elif len(valuation_columns) > 1:
    raise ValueError("Multiple columns containing 'Valuation' found. Please specify.")
else:
    valuation_column = valuation_columns[0]

# Clean and prepare data.
# Strip surrounding whitespace from every string cell first, so padded
# valuation strings are already clean when they are converted to numbers.
# Series.map is used per column instead of DataFrame.applymap, which is
# deprecated in recent pandas releases.
data = data.apply(
    lambda column: column.map(lambda x: x.strip() if isinstance(x, str) else x)
)

# Remove "$" and thousands separators, then coerce to a number; anything that
# still cannot be parsed becomes NaN. The pattern is a raw string because
# "\$" is not a valid escape sequence in a normal string literal.
data["valuation_billions"] = data[valuation_column].replace({r'\$': '', ',': ''}, regex=True)
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')

# Rename columns for consistency (optional)
data = data.rename(columns={
    "company": "Company",
    "valuation_billions": "Valuation_Billions",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors"
})

# Parse the "Select_Investors" column to map investors to companies
def build_investor_company_mapping(df):
    """Map each investor name to the list of companies it is listed on.

    Args:
        df: DataFrame with a "Company" column and a "Select_Investors" column
            whose cells hold comma-separated investor names.

    Returns:
        dict mapping investor name -> list of company values, in row order.
        Rows whose investor cell is missing or not a string are skipped, and
        empty names (e.g. from a trailing or doubled comma) are ignored.
    """
    mapping = {}
    # Iterate the two columns directly; this avoids building a Series per row.
    for company, investors in zip(df["Company"], df["Select_Investors"]):
        if not isinstance(investors, str):
            # NaN/None or any other non-text cell carries no investor names.
            continue
        for investor in investors.split(","):
            investor = investor.strip()
            if investor:  # "A, B," must not create an investor named ""
                mapping.setdefault(investor, []).append(company)
    return mapping

# Investor -> companies mapping over the full, unfiltered dataset, built once at import time.
# NOTE(review): nothing in the visible code reads this name; confirm before relying on it.
investor_company_mapping = build_investor_company_mapping(data)

# Function to filter investors based on selected country and industry
def filter_investors_by_country_and_industry(selected_country, selected_industry, min_total_valuation=20):
    """Return the investors that pass the valuation threshold for a filter.

    Args:
        selected_country: Value to match in the "Country" column, or "All"
            to skip the country filter.
        selected_industry: Value to match in the "Industry" column, or "All"
            to skip the industry filter.
        min_total_valuation: Minimum summed "Valuation_Billions" an investor's
            companies must reach (inclusive). Defaults to 20, the value that
            was previously hard-coded.

    Returns:
        A tuple ``(investors, filtered_data)``: the list of qualifying
        investor names and the filtered copy of the module-level ``data``.
    """
    filtered_data = data.copy()
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)

    # Look the columns up once instead of on every investor.
    company_column = filtered_data["Company"]
    valuation_column = filtered_data["Valuation_Billions"]

    # Keep an investor when the valuations of all rows belonging to its
    # companies add up to at least the threshold (NaN valuations are skipped
    # by Series.sum).
    qualifying_investors = [
        investor
        for investor, companies in investor_company_mapping_filtered.items()
        if valuation_column[company_column.isin(companies)].sum() >= min_total_valuation
    ]

    return qualifying_investors, filtered_data

# Function to generate the graph
def generate_graph(selected_investors, filtered_data):
    """Render the investor/company network for the selected investors.

    Args:
        selected_investors: Iterable of investor names to include.
        filtered_data: DataFrame with "Company", "Select_Investors" and
            "Valuation_Billions" columns, or None when no filter has been
            applied yet.

    Returns:
        A PIL image of the rendered graph, or None when there is nothing to
        draw (no selection, no data, or none of the selected investors occur
        in ``filtered_data``).
    """
    if not selected_investors or filtered_data is None:
        return None

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    # A selection can name investors that are absent from the current filtered
    # data; skip those instead of failing with a KeyError.
    filtered_mapping = {
        inv: investor_company_mapping_filtered[inv]
        for inv in selected_investors
        if inv in investor_company_mapping_filtered
    }
    if not filtered_mapping:
        return None

    # Build the graph: one edge per (investor, company) pair.
    G = nx.Graph()
    for investor, companies in filtered_mapping.items():
        G.add_edges_from((investor, company) for company in companies)

    # Sum valuations per company once, rather than scanning the frame per node.
    company_valuations = (
        filtered_data.groupby("Company", sort=False)["Valuation_Billions"].sum().to_dict()
    )
    max_valuation = filtered_data["Valuation_Billions"].max()
    # max() is NaN when every valuation is missing; NaN is truthy, so test it
    # explicitly to avoid NaN node sizes.
    can_scale = pd.notna(max_valuation) and max_valuation > 0

    node_sizes = []
    node_colors = []
    for node in G.nodes:
        if node in filtered_mapping:
            node_sizes.append(1500)  # Fixed size for investors
            node_colors.append("#FF8C00")  # Investors: orange
        else:
            valuation = company_valuations.get(node, 0.0)
            # Company size is proportional to its share of the largest valuation.
            node_sizes.append((valuation / max_valuation) * 1500 if can_scale else 100)
            node_colors.append("#32CD32")  # Companies: green

    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label='Investor', markersize=10, markerfacecolor='#FF8C00'),
        Line2D([0], [0], marker='o', color='w', label='Company', markersize=10, markerfacecolor='#32CD32')
    ]

    # Draw on an explicit figure/axes rather than pyplot's implicit "current"
    # figure, and always close it so a failure while drawing does not leak it.
    fig, ax = plt.subplots(figsize=(15, 15))
    try:
        pos = nx.spring_layout(G, k=0.2, seed=42)  # fixed seed -> reproducible layout
        nx.draw(
            G, pos,
            ax=ax,
            with_labels=True,
            node_size=node_sizes,
            node_color=node_colors,
            font_size=10,
            edge_color="#A9A9A9",  # Light gray edges
            alpha=0.9
        )
        ax.legend(handles=legend_elements, loc='upper left')
        ax.set_title("Venture Network Visualization", fontsize=20)
        ax.axis("off")

        # Save plot to BytesIO
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)

    return Image.open(buf)

# Gradio app function
def app(selected_country, selected_industry):
    """Recompute the investor choices for the current country/industry filter.

    Args:
        selected_country: Country dropdown value ("All" disables the filter).
        selected_industry: Industry dropdown value ("All" disables the filter).

    Returns:
        A tuple of (component update, filtered DataFrame). The update makes
        the investor checkbox group visible, sets its choices to the
        qualifying investors and pre-selects all of them.
    """
    investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
    # Use the generic gr.update() helper: the per-component
    # CheckboxGroup.update() classmethod was removed in Gradio 4, while
    # gr.update() is available in both 3.x and 4.x.
    return gr.update(
        choices=investor_list,
        value=investor_list,
        visible=True
    ), filtered_data

# Gradio Interface
def main():
    """Build the Gradio Blocks UI, wire up its events and launch it.

    Layout: a row with country and industry dropdowns, an initially hidden
    investor checkbox group, the graph image, and a State holding the
    filtered DataFrame. Changing either dropdown re-runs ``app``; changing
    the investor selection re-runs ``generate_graph``.
    """
    def dropdown_choices(column):
        # "All" first, then the sorted distinct non-null values of the column.
        return ["All"] + sorted(data[column].dropna().unique())

    country_list = dropdown_choices("Country")
    industry_list = dropdown_choices("Industry")

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")

        filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
        graph_output = gr.Image(type="pil", label="Venture Network Graph")

        filtered_data_holder = gr.State()

        # Both dropdowns trigger the same refresh with the same inputs/outputs.
        for dropdown in (country_filter, industry_filter):
            dropdown.change(
                app,
                inputs=[country_filter, industry_filter],
                outputs=[filtered_investor_list, filtered_data_holder],
            )

        # Redraw the graph whenever the investor selection changes.
        filtered_investor_list.change(
            generate_graph,
            inputs=[filtered_investor_list, filtered_data_holder],
            outputs=graph_output,
        )

    demo.launch()

# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    main()