File size: 6,407 Bytes
a165958
45a7450
 
 
9349152
a165958
45a7450
a165958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9327810
 
e63418f
 
9327810
e63418f
 
 
 
45a7450
9327810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e63418f
 
 
9327810
 
 
 
e63418f
 
 
 
45a7450
a165958
45a7450
a165958
 
 
e63418f
a165958
 
e63418f
09b019d
45a7450
a165958
45a7450
 
e63418f
a165958
45a7450
a165958
45a7450
a165958
45a7450
a165958
45a7450
 
 
99eb020
45a7450
09b019d
 
45a7450
09b019d
 
45a7450
09b019d
 
 
 
 
 
 
a165958
45a7450
 
a165958
45a7450
 
 
 
9349152
 
 
 
45a7450
9327810
 
 
 
 
 
 
 
 
a165958
45a7450
e63418f
 
45a7450
9327810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45a7450
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import gradio as gr

# Load and preprocess the dataset
file_path = "cbinsights_data.csv"  # Replace with your file path
data = pd.read_csv(file_path)

# The first data row is a header row in this export (presumably a second
# header line in the CSV -- confirm against the file); drop it and assign the
# canonical column names used throughout the app. `.copy()` makes the frame
# independent of the slice so the column assignments below don't warn.
data = data.iloc[1:].copy()
data.columns = ["Company", "Valuation_Billions", "Date_Joined", "Country", "City", "Industry", "Select_Investors"]


def _strip_if_str(value):
    """Return *value* with surrounding whitespace removed if it is a string."""
    return value.strip() if isinstance(value, str) else value


# Clean and prepare data.
# Strip whitespace first so the valuation parsing and every later equality
# filter (Country, Industry, Company) see clean values. Series.map is used
# per column instead of the deprecated DataFrame.applymap.
for _column in data.columns:
    data[_column] = data[_column].map(_strip_if_str)

# Valuations look like "$1.5": drop the currency sign (literally, not as a
# regex) and parse the full decimal value. The previous split on "." kept only
# the integer part, so e.g. $1.5B was counted as 1. Unparseable cells -> NaN.
data["Valuation_Billions"] = pd.to_numeric(
    data["Valuation_Billions"].str.replace("$", "", regex=False),
    errors="coerce",
)

# Parse the "Select_Investors" column (a comma-separated list of investor
# names per company) into an investor -> [companies] mapping.
investor_company_mapping = {}
for company, investors in zip(data["Company"], data["Select_Investors"]):
    if pd.isnull(investors):
        continue
    for investor in investors.split(","):
        investor = investor.strip()
        if not investor:
            # Trailing or doubled commas yield empty names; don't create a "" investor.
            continue
        companies = investor_company_mapping.setdefault(investor, [])
        # A name repeated within one cell, or a duplicated company row, would
        # otherwise list the company twice and double-count its valuation.
        if company not in companies:
            companies.append(company)

# Gradio app functions
def filter_investors_by_country_and_industry(selected_country, selected_industry, min_total_valuation=20):
    """Filter companies by country/industry and return investors that clear a valuation bar.

    Args:
        selected_country: Country to keep, or "All" to skip the country filter.
        selected_industry: Industry to keep, or "All" to skip the industry filter.
        min_total_valuation: Minimum summed ``Valuation_Billions`` over an
            investor's companies (restricted to the filtered rows) for the
            investor to be returned. Defaults to 20, the previous fixed value.

    Returns:
        ``(investors, filtered_data)``: the qualifying investor names in
        ``investor_company_mapping`` order, and the filtered DataFrame.
    """
    filtered_data = data

    # Apply filters
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    # Build a company -> valuation lookup once instead of scanning the frame
    # for every (investor, company) pair. Keeping the first row per company
    # matches the earlier `.values[0]` lookup. Missing valuations count as 0
    # so one NaN no longer turns an investor's total into NaN (which silently
    # failed the threshold comparison and dropped the investor).
    valuation_by_company = (
        filtered_data.dropna(subset=["Company"])
        .drop_duplicates(subset="Company", keep="first")
        .set_index("Company")["Valuation_Billions"]
        .fillna(0)
        .to_dict()
    )

    qualifying_investors = []
    for investor, companies in investor_company_mapping.items():
        # dict.fromkeys de-duplicates (order-preserving) so a company listed
        # twice for the same investor is only counted once.
        total_valuation = sum(
            valuation_by_company.get(company, 0) for company in dict.fromkeys(companies)
        )
        if total_valuation >= min_total_valuation:
            qualifying_investors.append(investor)

    return qualifying_investors, filtered_data

def generate_graph(selected_investors, filtered_data):
    """Render the investor/company network for the selected investors as a PIL image.

    Investors are drawn at a fixed size in orange; companies are drawn in
    green, sized by valuation (50 marker units per billion).

    Args:
        selected_investors: Investor names ticked in the UI; falsy means nothing to draw.
        filtered_data: The filtered DataFrame produced by
            ``filter_investors_by_country_and_industry``; only companies present
            in it are drawn. May be ``None`` before any filter has run.

    Returns:
        A PIL ``Image`` of the rendered graph, or ``None`` when there is nothing to render.
    """
    if not selected_investors or filtered_data is None:
        return None

    size_per_billion = 50  # marker area per $1B of valuation
    selected = set(selected_investors)
    visible_companies = set(filtered_data["Company"].dropna())

    # Keep only the selected investors' companies that survive the filters.
    filtered_mapping = {}
    for investor, companies in investor_company_mapping.items():
        if investor in selected:
            kept = [c for c in companies if c in visible_companies]
            if kept:
                filtered_mapping[investor] = kept

    # Use the filtered mapping to build the graph
    G = nx.Graph()
    for investor, companies in filtered_mapping.items():
        for company in companies:
            G.add_edge(investor, company)

    # First valuation per company, mirroring the earlier `.values[0]` lookup.
    valuation_by_company = (
        filtered_data.dropna(subset=["Company"])
        .drop_duplicates(subset="Company", keep="first")
        .set_index("Company")["Valuation_Billions"]
        .to_dict()
    )

    node_sizes = []
    node_colors = []
    for node in G.nodes:
        if node in filtered_mapping:
            node_sizes.append(2000)  # Fixed size for investors
            node_colors.append("#FF5733")  # Distinct color for investors
        else:
            valuation = valuation_by_company.get(node)
            # Missing or unparsed (NaN) valuations get a small default size
            # rather than handing matplotlib a NaN marker size.
            node_sizes.append(valuation * size_per_billion if pd.notna(valuation) else 100)
            node_colors.append("#33FF57")  # Distinct color for companies

    # Draw on an explicit figure and always close it, so a failure while
    # drawing doesn't leave a figure open in pyplot's global state.
    fig, ax = plt.subplots(figsize=(18, 18))
    try:
        pos = nx.spring_layout(G, k=0.2, seed=42)  # Fixed seed for consistent layout
        nx.draw(
            G, pos,
            ax=ax,
            with_labels=True,
            node_size=node_sizes,
            node_color=node_colors,
            alpha=0.8,  # Slight transparency for Tufte-inspired visuals
            font_size=10,
            font_weight="bold",
            edge_color="#B0BEC5",  # Neutral, muted edge color
            width=0.8,  # Thin edges for minimal visual clutter
        )

        # Legend for node size, using the same scale as the company nodes.
        for billions, label in zip([1, 100], ["$1B", "$100B"]):
            ax.scatter([], [], s=billions * size_per_billion, color="#33FF57", label=f"{label} valuation")
        ax.legend(scatterpoints=1, frameon=False, labelspacing=1.5, loc="lower left", fontsize=12)

        ax.set_title("Venture Funded Companies Visualization", fontsize=20)
        ax.axis("off")

        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)

def app(selected_country, selected_industry):
    """Gradio callback: refresh the investor checklist for the chosen filters.

    Returns a ``gr.update`` that reveals the checkbox group with every
    qualifying investor both offered and pre-selected, together with the
    filtered DataFrame to stash in the ``gr.State`` holder.
    """
    investors, subset = filter_investors_by_country_and_industry(
        selected_country, selected_industry
    )
    checklist_update = gr.update(choices=investors, value=investors, visible=True)
    return checklist_update, subset

# Gradio Interface
def main():
    """Build and launch the Gradio UI for exploring the investor network.

    Changing either dropdown recomputes the investor checklist via ``app``;
    changing the checklist re-renders the graph via ``generate_graph``.
    """
    country_list = ["All"] + sorted(data["Country"].dropna().unique())
    industry_list = ["All"] + sorted(data["Industry"].dropna().unique())

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
            industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")

        filtered_investor_list = gr.CheckboxGroup(
            choices=[],
            label="Select Investors",
            visible=False
        )
        graph_output = gr.Image(type="pil", label="Venture Network Graph")

        filtered_data_holder = gr.State()

        filter_inputs = [country_filter, industry_filter]
        filter_outputs = [filtered_investor_list, filtered_data_holder]

        # Either dropdown changing recomputes the investor list.
        for dropdown in filter_inputs:
            dropdown.change(app, inputs=filter_inputs, outputs=filter_outputs)

        # The default "All"/"All" selection never fires a .change event, so
        # populate the investor list (and hence the graph) once on page load.
        demo.load(app, inputs=filter_inputs, outputs=filter_outputs)

        filtered_investor_list.change(
            generate_graph,
            inputs=[filtered_investor_list, filtered_data_holder],
            outputs=graph_output
        )

    demo.launch()


if __name__ == "__main__":
    main()