Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,31 +4,52 @@ import matplotlib.pyplot as plt
|
|
4 |
from io import BytesIO
|
5 |
from PIL import Image
|
6 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# Load and preprocess the dataset
|
9 |
-
file_path = "cbinsights_data.csv" # Replace with your file path
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Standardize column names: strip whitespace and convert to lowercase
|
13 |
data.columns = data.columns.str.strip().str.lower()
|
14 |
-
|
15 |
|
16 |
# Identify the valuation column dynamically
|
17 |
valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
|
18 |
if not valuation_columns:
|
19 |
-
|
|
|
20 |
elif len(valuation_columns) > 1:
|
21 |
-
|
|
|
22 |
else:
|
23 |
valuation_column = valuation_columns[0]
|
|
|
24 |
|
25 |
# Clean and prepare data
|
26 |
data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
|
27 |
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
|
28 |
-
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
# Rename columns for consistency
|
31 |
-
|
32 |
"company": "Company",
|
33 |
"valuation_billions": "Valuation_Billions",
|
34 |
"date_joined": "Date_Joined",
|
@@ -36,7 +57,15 @@ data = data.rename(columns={
|
|
36 |
"city": "City",
|
37 |
"industry": "Industry",
|
38 |
"select_investors": "Select_Investors"
|
39 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
# Parse the "Select_Investors" column to map investors to companies
|
42 |
def build_investor_company_mapping(df):
|
@@ -47,18 +76,24 @@ def build_investor_company_mapping(df):
|
|
47 |
if pd.notnull(investors):
|
48 |
for investor in investors.split(","):
|
49 |
investor = investor.strip()
|
50 |
-
|
|
|
51 |
return mapping
|
52 |
|
53 |
investor_company_mapping = build_investor_company_mapping(data)
|
|
|
54 |
|
55 |
# Function to filter investors based on selected country and industry
|
56 |
def filter_investors_by_country_and_industry(selected_country, selected_industry):
|
57 |
filtered_data = data.copy()
|
|
|
|
|
58 |
if selected_country != "All":
|
59 |
filtered_data = filtered_data[filtered_data["Country"] == selected_country]
|
|
|
60 |
if selected_industry != "All":
|
61 |
filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
|
|
|
62 |
|
63 |
investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
|
64 |
|
@@ -69,22 +104,27 @@ def filter_investors_by_country_and_industry(selected_country, selected_industry
|
|
69 |
if total_valuation >= 20: # Investors with >= 20B total valuation
|
70 |
investor_valuations[investor] = total_valuation
|
71 |
|
|
|
|
|
72 |
return list(investor_valuations.keys()), filtered_data
|
73 |
|
74 |
# Function to generate the graph
|
75 |
def generate_graph(selected_investors, filtered_data):
|
76 |
if not selected_investors:
|
|
|
77 |
return None
|
78 |
|
79 |
investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
|
80 |
-
filtered_mapping = {inv: investor_company_mapping_filtered[inv] for inv in selected_investors}
|
81 |
-
|
|
|
|
|
82 |
# Build the graph
|
83 |
G = nx.Graph()
|
84 |
for investor, companies in filtered_mapping.items():
|
85 |
for company in companies:
|
86 |
G.add_edge(investor, company)
|
87 |
-
|
88 |
# Node size based on valuation
|
89 |
max_valuation = filtered_data["Valuation_Billions"].max()
|
90 |
node_sizes = []
|
@@ -95,10 +135,10 @@ def generate_graph(selected_investors, filtered_data):
|
|
95 |
valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
|
96 |
size = (valuation / max_valuation) * 1500 if max_valuation else 100
|
97 |
node_sizes.append(size)
|
98 |
-
|
99 |
# Node color: Investors (orange), Companies (green)
|
100 |
node_colors = ["#FF8C00" if node in filtered_mapping else "#32CD32" for node in G.nodes]
|
101 |
-
|
102 |
# Draw the graph
|
103 |
plt.figure(figsize=(15, 15))
|
104 |
pos = nx.spring_layout(G, k=0.2, seed=42)
|
@@ -111,7 +151,7 @@ def generate_graph(selected_investors, filtered_data):
|
|
111 |
edge_color="#A9A9A9", # Light gray edges
|
112 |
alpha=0.9
|
113 |
)
|
114 |
-
|
115 |
# Legend
|
116 |
from matplotlib.lines import Line2D
|
117 |
legend_elements = [
|
@@ -119,22 +159,27 @@ def generate_graph(selected_investors, filtered_data):
|
|
119 |
Line2D([0], [0], marker='o', color='w', label='Company', markersize=10, markerfacecolor='#32CD32')
|
120 |
]
|
121 |
plt.legend(handles=legend_elements, loc='upper left')
|
122 |
-
|
123 |
plt.title("Venture Network Visualization", fontsize=20)
|
124 |
plt.axis("off")
|
125 |
-
|
126 |
# Save plot to BytesIO
|
127 |
buf = BytesIO()
|
128 |
plt.savefig(buf, format="png", bbox_inches="tight")
|
129 |
plt.close()
|
130 |
buf.seek(0)
|
131 |
-
|
|
|
|
|
132 |
return Image.open(buf)
|
133 |
|
134 |
# Gradio app function
|
135 |
def app(selected_country, selected_industry):
|
136 |
investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
|
137 |
-
|
|
|
|
|
|
|
138 |
choices=investor_list,
|
139 |
value=investor_list,
|
140 |
visible=True
|
@@ -144,17 +189,20 @@ def app(selected_country, selected_industry):
|
|
144 |
def main():
|
145 |
country_list = ["All"] + sorted(data["Country"].dropna().unique())
|
146 |
industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
|
147 |
-
|
|
|
|
|
|
|
148 |
with gr.Blocks() as demo:
|
149 |
with gr.Row():
|
150 |
country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
|
151 |
industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
|
152 |
-
|
153 |
filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
|
154 |
graph_output = gr.Image(type="pil", label="Venture Network Graph")
|
155 |
-
|
156 |
filtered_data_holder = gr.State()
|
157 |
-
|
158 |
country_filter.change(
|
159 |
app,
|
160 |
inputs=[country_filter, industry_filter],
|
@@ -165,13 +213,13 @@ def main():
|
|
165 |
inputs=[country_filter, industry_filter],
|
166 |
outputs=[filtered_investor_list, filtered_data_holder]
|
167 |
)
|
168 |
-
|
169 |
filtered_investor_list.change(
|
170 |
generate_graph,
|
171 |
inputs=[filtered_investor_list, filtered_data_holder],
|
172 |
outputs=graph_output
|
173 |
)
|
174 |
-
|
175 |
demo.launch()
|
176 |
|
177 |
if __name__ == "__main__":
|
|
|
4 |
from io import BytesIO
|
5 |
from PIL import Image
|
6 |
import gradio as gr
|
7 |
+
import logging
|
8 |
+
|
9 |
+
# Set up logging
|
10 |
+
logging.basicConfig(level=logging.INFO)
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
|
13 |
# Load and preprocess the dataset
|
14 |
+
file_path = "cbinsights_data.csv" # Replace with your actual file path
|
15 |
+
|
16 |
+
try:
|
17 |
+
data = pd.read_csv(file_path)
|
18 |
+
logger.info("CSV file loaded successfully.")
|
19 |
+
except FileNotFoundError:
|
20 |
+
logger.error(f"File not found: {file_path}")
|
21 |
+
raise
|
22 |
+
except Exception as e:
|
23 |
+
logger.error(f"Error loading CSV file: {e}")
|
24 |
+
raise
|
25 |
|
26 |
# Standardize column names: strip whitespace and convert to lowercase
|
27 |
data.columns = data.columns.str.strip().str.lower()
|
28 |
+
logger.info(f"Standardized Column Names: {data.columns.tolist()}")
|
29 |
|
30 |
# Identify the valuation column dynamically
|
31 |
valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
|
32 |
if not valuation_columns:
|
33 |
+
logger.error("No column containing 'Valuation' found in the dataset.")
|
34 |
+
raise ValueError("Data Error: Unable to find the valuation column. Please check your CSV file.")
|
35 |
elif len(valuation_columns) > 1:
|
36 |
+
logger.error("Multiple columns containing 'Valuation' found in the dataset.")
|
37 |
+
raise ValueError("Data Error: Multiple valuation columns detected. Please ensure only one valuation column exists.")
|
38 |
else:
|
39 |
valuation_column = valuation_columns[0]
|
40 |
+
logger.info(f"Using valuation column: {valuation_column}")
|
41 |
|
42 |
# Clean and prepare data
|
43 |
data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
|
44 |
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
|
45 |
+
logger.info("Valuation data cleaned and converted to numeric.")
|
46 |
+
|
47 |
+
# Strip whitespace from all string columns
|
48 |
+
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
|
49 |
+
logger.info("Whitespace stripped from all string columns.")
|
50 |
|
51 |
+
# Rename columns for consistency
|
52 |
+
expected_columns = {
|
53 |
"company": "Company",
|
54 |
"valuation_billions": "Valuation_Billions",
|
55 |
"date_joined": "Date_Joined",
|
|
|
57 |
"city": "City",
|
58 |
"industry": "Industry",
|
59 |
"select_investors": "Select_Investors"
|
60 |
+
}
|
61 |
+
|
62 |
+
missing_columns = set(expected_columns.keys()) - set(data.columns)
|
63 |
+
if missing_columns:
|
64 |
+
logger.error(f"Missing columns in the dataset: {missing_columns}")
|
65 |
+
raise ValueError(f"Data Error: Missing columns {missing_columns} in the dataset.")
|
66 |
+
|
67 |
+
data = data.rename(columns=expected_columns)
|
68 |
+
logger.info("Columns renamed for consistency.")
|
69 |
|
70 |
# Parse the "Select_Investors" column to map investors to companies
|
71 |
def build_investor_company_mapping(df):
|
|
|
76 |
if pd.notnull(investors):
|
77 |
for investor in investors.split(","):
|
78 |
investor = investor.strip()
|
79 |
+
if investor: # Ensure investor is not an empty string
|
80 |
+
mapping.setdefault(investor, []).append(company)
|
81 |
return mapping
|
82 |
|
83 |
investor_company_mapping = build_investor_company_mapping(data)
|
84 |
+
logger.info("Investor to company mapping created.")
|
85 |
|
86 |
# Function to filter investors based on selected country and industry
|
87 |
def filter_investors_by_country_and_industry(selected_country, selected_industry):
|
88 |
filtered_data = data.copy()
|
89 |
+
logger.info(f"Filtering data for Country: {selected_country}, Industry: {selected_industry}")
|
90 |
+
|
91 |
if selected_country != "All":
|
92 |
filtered_data = filtered_data[filtered_data["Country"] == selected_country]
|
93 |
+
logger.info(f"Data filtered by country: {selected_country}. Remaining records: {len(filtered_data)}")
|
94 |
if selected_industry != "All":
|
95 |
filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
|
96 |
+
logger.info(f"Data filtered by industry: {selected_industry}. Remaining records: {len(filtered_data)}")
|
97 |
|
98 |
investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
|
99 |
|
|
|
104 |
if total_valuation >= 20: # Investors with >= 20B total valuation
|
105 |
investor_valuations[investor] = total_valuation
|
106 |
|
107 |
+
logger.info(f"Filtered investors with total valuation >= 20B: {len(investor_valuations)}")
|
108 |
+
|
109 |
return list(investor_valuations.keys()), filtered_data
|
110 |
|
111 |
# Function to generate the graph
|
112 |
def generate_graph(selected_investors, filtered_data):
|
113 |
if not selected_investors:
|
114 |
+
logger.warning("No investors selected. Returning None for graph.")
|
115 |
return None
|
116 |
|
117 |
investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
|
118 |
+
filtered_mapping = {inv: investor_company_mapping_filtered[inv] for inv in selected_investors if inv in investor_company_mapping_filtered}
|
119 |
+
|
120 |
+
logger.info(f"Generating graph for {len(filtered_mapping)} investors.")
|
121 |
+
|
122 |
# Build the graph
|
123 |
G = nx.Graph()
|
124 |
for investor, companies in filtered_mapping.items():
|
125 |
for company in companies:
|
126 |
G.add_edge(investor, company)
|
127 |
+
|
128 |
# Node size based on valuation
|
129 |
max_valuation = filtered_data["Valuation_Billions"].max()
|
130 |
node_sizes = []
|
|
|
135 |
valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
|
136 |
size = (valuation / max_valuation) * 1500 if max_valuation else 100
|
137 |
node_sizes.append(size)
|
138 |
+
|
139 |
# Node color: Investors (orange), Companies (green)
|
140 |
node_colors = ["#FF8C00" if node in filtered_mapping else "#32CD32" for node in G.nodes]
|
141 |
+
|
142 |
# Draw the graph
|
143 |
plt.figure(figsize=(15, 15))
|
144 |
pos = nx.spring_layout(G, k=0.2, seed=42)
|
|
|
151 |
edge_color="#A9A9A9", # Light gray edges
|
152 |
alpha=0.9
|
153 |
)
|
154 |
+
|
155 |
# Legend
|
156 |
from matplotlib.lines import Line2D
|
157 |
legend_elements = [
|
|
|
159 |
Line2D([0], [0], marker='o', color='w', label='Company', markersize=10, markerfacecolor='#32CD32')
|
160 |
]
|
161 |
plt.legend(handles=legend_elements, loc='upper left')
|
162 |
+
|
163 |
plt.title("Venture Network Visualization", fontsize=20)
|
164 |
plt.axis("off")
|
165 |
+
|
166 |
# Save plot to BytesIO
|
167 |
buf = BytesIO()
|
168 |
plt.savefig(buf, format="png", bbox_inches="tight")
|
169 |
plt.close()
|
170 |
buf.seek(0)
|
171 |
+
|
172 |
+
logger.info("Graph generated successfully.")
|
173 |
+
|
174 |
return Image.open(buf)
|
175 |
|
176 |
# Gradio app function
|
177 |
def app(selected_country, selected_industry):
|
178 |
investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
|
179 |
+
logger.info("Updating CheckboxGroup and filtered data holder.")
|
180 |
+
|
181 |
+
# Use gr.update() to create an update dictionary for the CheckboxGroup
|
182 |
+
return gr.update(
|
183 |
choices=investor_list,
|
184 |
value=investor_list,
|
185 |
visible=True
|
|
|
189 |
def main():
|
190 |
country_list = ["All"] + sorted(data["Country"].dropna().unique())
|
191 |
industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
|
192 |
+
|
193 |
+
logger.info(f"Available countries: {country_list}")
|
194 |
+
logger.info(f"Available industries: {industry_list}")
|
195 |
+
|
196 |
with gr.Blocks() as demo:
|
197 |
with gr.Row():
|
198 |
country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
|
199 |
industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
|
200 |
+
|
201 |
filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
|
202 |
graph_output = gr.Image(type="pil", label="Venture Network Graph")
|
203 |
+
|
204 |
filtered_data_holder = gr.State()
|
205 |
+
|
206 |
country_filter.change(
|
207 |
app,
|
208 |
inputs=[country_filter, industry_filter],
|
|
|
213 |
inputs=[country_filter, industry_filter],
|
214 |
outputs=[filtered_investor_list, filtered_data_holder]
|
215 |
)
|
216 |
+
|
217 |
filtered_investor_list.change(
|
218 |
generate_graph,
|
219 |
inputs=[filtered_investor_list, filtered_data_holder],
|
220 |
outputs=graph_output
|
221 |
)
|
222 |
+
|
223 |
demo.launch()
|
224 |
|
225 |
if __name__ == "__main__":
|