LeonceNsh committed on
Commit
b951dc3
·
verified ·
1 Parent(s): cc9514f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -26
app.py CHANGED
@@ -4,31 +4,52 @@ import matplotlib.pyplot as plt
4
  from io import BytesIO
5
  from PIL import Image
6
  import gradio as gr
 
 
 
 
 
7
 
8
  # Load and preprocess the dataset
9
- file_path = "cbinsights_data.csv" # Replace with your file path
10
- data = pd.read_csv(file_path, skiprows=1)
 
 
 
 
 
 
 
 
 
11
 
12
  # Standardize column names: strip whitespace and convert to lowercase
13
  data.columns = data.columns.str.strip().str.lower()
14
- print("Standardized Column Names:", data.columns.tolist())
15
 
16
  # Identify the valuation column dynamically
17
  valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
18
  if not valuation_columns:
19
- raise ValueError("No column containing 'Valuation' found in the dataset.")
 
20
  elif len(valuation_columns) > 1:
21
- raise ValueError("Multiple columns containing 'Valuation' found. Please specify.")
 
22
  else:
23
  valuation_column = valuation_columns[0]
 
24
 
25
  # Clean and prepare data
26
  data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
27
  data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
28
- data = data.applymap(lambda x: x.strip() if isinstance(x, str) else x)
 
 
 
 
29
 
30
- # Rename columns for consistency (optional)
31
- data = data.rename(columns={
32
  "company": "Company",
33
  "valuation_billions": "Valuation_Billions",
34
  "date_joined": "Date_Joined",
@@ -36,7 +57,15 @@ data = data.rename(columns={
36
  "city": "City",
37
  "industry": "Industry",
38
  "select_investors": "Select_Investors"
39
- })
 
 
 
 
 
 
 
 
40
 
41
  # Parse the "Select_Investors" column to map investors to companies
42
  def build_investor_company_mapping(df):
@@ -47,18 +76,24 @@ def build_investor_company_mapping(df):
47
  if pd.notnull(investors):
48
  for investor in investors.split(","):
49
  investor = investor.strip()
50
- mapping.setdefault(investor, []).append(company)
 
51
  return mapping
52
 
53
  investor_company_mapping = build_investor_company_mapping(data)
 
54
 
55
  # Function to filter investors based on selected country and industry
56
  def filter_investors_by_country_and_industry(selected_country, selected_industry):
57
  filtered_data = data.copy()
 
 
58
  if selected_country != "All":
59
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
 
60
  if selected_industry != "All":
61
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
 
62
 
63
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
64
 
@@ -69,22 +104,27 @@ def filter_investors_by_country_and_industry(selected_country, selected_industry
69
  if total_valuation >= 20: # Investors with >= 20B total valuation
70
  investor_valuations[investor] = total_valuation
71
 
 
 
72
  return list(investor_valuations.keys()), filtered_data
73
 
74
  # Function to generate the graph
75
  def generate_graph(selected_investors, filtered_data):
76
  if not selected_investors:
 
77
  return None
78
 
79
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
80
- filtered_mapping = {inv: investor_company_mapping_filtered[inv] for inv in selected_investors}
81
-
 
 
82
  # Build the graph
83
  G = nx.Graph()
84
  for investor, companies in filtered_mapping.items():
85
  for company in companies:
86
  G.add_edge(investor, company)
87
-
88
  # Node size based on valuation
89
  max_valuation = filtered_data["Valuation_Billions"].max()
90
  node_sizes = []
@@ -95,10 +135,10 @@ def generate_graph(selected_investors, filtered_data):
95
  valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
96
  size = (valuation / max_valuation) * 1500 if max_valuation else 100
97
  node_sizes.append(size)
98
-
99
  # Node color: Investors (orange), Companies (green)
100
  node_colors = ["#FF8C00" if node in filtered_mapping else "#32CD32" for node in G.nodes]
101
-
102
  # Draw the graph
103
  plt.figure(figsize=(15, 15))
104
  pos = nx.spring_layout(G, k=0.2, seed=42)
@@ -111,7 +151,7 @@ def generate_graph(selected_investors, filtered_data):
111
  edge_color="#A9A9A9", # Light gray edges
112
  alpha=0.9
113
  )
114
-
115
  # Legend
116
  from matplotlib.lines import Line2D
117
  legend_elements = [
@@ -119,22 +159,27 @@ def generate_graph(selected_investors, filtered_data):
119
  Line2D([0], [0], marker='o', color='w', label='Company', markersize=10, markerfacecolor='#32CD32')
120
  ]
121
  plt.legend(handles=legend_elements, loc='upper left')
122
-
123
  plt.title("Venture Network Visualization", fontsize=20)
124
  plt.axis("off")
125
-
126
  # Save plot to BytesIO
127
  buf = BytesIO()
128
  plt.savefig(buf, format="png", bbox_inches="tight")
129
  plt.close()
130
  buf.seek(0)
131
-
 
 
132
  return Image.open(buf)
133
 
134
  # Gradio app function
135
  def app(selected_country, selected_industry):
136
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
137
- return gr.CheckboxGroup.update(
 
 
 
138
  choices=investor_list,
139
  value=investor_list,
140
  visible=True
@@ -144,17 +189,20 @@ def app(selected_country, selected_industry):
144
  def main():
145
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
146
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
147
-
 
 
 
148
  with gr.Blocks() as demo:
149
  with gr.Row():
150
  country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
151
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
152
-
153
  filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
154
  graph_output = gr.Image(type="pil", label="Venture Network Graph")
155
-
156
  filtered_data_holder = gr.State()
157
-
158
  country_filter.change(
159
  app,
160
  inputs=[country_filter, industry_filter],
@@ -165,13 +213,13 @@ def main():
165
  inputs=[country_filter, industry_filter],
166
  outputs=[filtered_investor_list, filtered_data_holder]
167
  )
168
-
169
  filtered_investor_list.change(
170
  generate_graph,
171
  inputs=[filtered_investor_list, filtered_data_holder],
172
  outputs=graph_output
173
  )
174
-
175
  demo.launch()
176
 
177
  if __name__ == "__main__":
 
4
  from io import BytesIO
5
  from PIL import Image
6
  import gradio as gr
7
+ import logging
8
+
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
 
13
  # Load and preprocess the dataset
14
+ file_path = "cbinsights_data.csv" # Replace with your actual file path
15
+
16
+ try:
17
+ data = pd.read_csv(file_path)
18
+ logger.info("CSV file loaded successfully.")
19
+ except FileNotFoundError:
20
+ logger.error(f"File not found: {file_path}")
21
+ raise
22
+ except Exception as e:
23
+ logger.error(f"Error loading CSV file: {e}")
24
+ raise
25
 
26
  # Standardize column names: strip whitespace and convert to lowercase
27
  data.columns = data.columns.str.strip().str.lower()
28
+ logger.info(f"Standardized Column Names: {data.columns.tolist()}")
29
 
30
  # Identify the valuation column dynamically
31
  valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
32
  if not valuation_columns:
33
+ logger.error("No column containing 'Valuation' found in the dataset.")
34
+ raise ValueError("Data Error: Unable to find the valuation column. Please check your CSV file.")
35
  elif len(valuation_columns) > 1:
36
+ logger.error("Multiple columns containing 'Valuation' found in the dataset.")
37
+ raise ValueError("Data Error: Multiple valuation columns detected. Please ensure only one valuation column exists.")
38
  else:
39
  valuation_column = valuation_columns[0]
40
+ logger.info(f"Using valuation column: {valuation_column}")
41
 
42
  # Clean and prepare data
43
  data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
44
  data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
45
+ logger.info("Valuation data cleaned and converted to numeric.")
46
+
47
+ # Strip whitespace from all string columns
48
+ data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
49
+ logger.info("Whitespace stripped from all string columns.")
50
 
51
+ # Rename columns for consistency
52
+ expected_columns = {
53
  "company": "Company",
54
  "valuation_billions": "Valuation_Billions",
55
  "date_joined": "Date_Joined",
 
57
  "city": "City",
58
  "industry": "Industry",
59
  "select_investors": "Select_Investors"
60
+ }
61
+
62
+ missing_columns = set(expected_columns.keys()) - set(data.columns)
63
+ if missing_columns:
64
+ logger.error(f"Missing columns in the dataset: {missing_columns}")
65
+ raise ValueError(f"Data Error: Missing columns {missing_columns} in the dataset.")
66
+
67
+ data = data.rename(columns=expected_columns)
68
+ logger.info("Columns renamed for consistency.")
69
 
70
  # Parse the "Select_Investors" column to map investors to companies
71
  def build_investor_company_mapping(df):
 
76
  if pd.notnull(investors):
77
  for investor in investors.split(","):
78
  investor = investor.strip()
79
+ if investor: # Ensure investor is not an empty string
80
+ mapping.setdefault(investor, []).append(company)
81
  return mapping
82
 
83
  investor_company_mapping = build_investor_company_mapping(data)
84
+ logger.info("Investor to company mapping created.")
85
 
86
  # Function to filter investors based on selected country and industry
87
  def filter_investors_by_country_and_industry(selected_country, selected_industry):
88
  filtered_data = data.copy()
89
+ logger.info(f"Filtering data for Country: {selected_country}, Industry: {selected_industry}")
90
+
91
  if selected_country != "All":
92
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
93
+ logger.info(f"Data filtered by country: {selected_country}. Remaining records: {len(filtered_data)}")
94
  if selected_industry != "All":
95
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
96
+ logger.info(f"Data filtered by industry: {selected_industry}. Remaining records: {len(filtered_data)}")
97
 
98
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
99
 
 
104
  if total_valuation >= 20: # Investors with >= 20B total valuation
105
  investor_valuations[investor] = total_valuation
106
 
107
+ logger.info(f"Filtered investors with total valuation >= 20B: {len(investor_valuations)}")
108
+
109
  return list(investor_valuations.keys()), filtered_data
110
 
111
  # Function to generate the graph
112
  def generate_graph(selected_investors, filtered_data):
113
  if not selected_investors:
114
+ logger.warning("No investors selected. Returning None for graph.")
115
  return None
116
 
117
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
118
+ filtered_mapping = {inv: investor_company_mapping_filtered[inv] for inv in selected_investors if inv in investor_company_mapping_filtered}
119
+
120
+ logger.info(f"Generating graph for {len(filtered_mapping)} investors.")
121
+
122
  # Build the graph
123
  G = nx.Graph()
124
  for investor, companies in filtered_mapping.items():
125
  for company in companies:
126
  G.add_edge(investor, company)
127
+
128
  # Node size based on valuation
129
  max_valuation = filtered_data["Valuation_Billions"].max()
130
  node_sizes = []
 
135
  valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
136
  size = (valuation / max_valuation) * 1500 if max_valuation else 100
137
  node_sizes.append(size)
138
+
139
  # Node color: Investors (orange), Companies (green)
140
  node_colors = ["#FF8C00" if node in filtered_mapping else "#32CD32" for node in G.nodes]
141
+
142
  # Draw the graph
143
  plt.figure(figsize=(15, 15))
144
  pos = nx.spring_layout(G, k=0.2, seed=42)
 
151
  edge_color="#A9A9A9", # Light gray edges
152
  alpha=0.9
153
  )
154
+
155
  # Legend
156
  from matplotlib.lines import Line2D
157
  legend_elements = [
 
159
  Line2D([0], [0], marker='o', color='w', label='Company', markersize=10, markerfacecolor='#32CD32')
160
  ]
161
  plt.legend(handles=legend_elements, loc='upper left')
162
+
163
  plt.title("Venture Network Visualization", fontsize=20)
164
  plt.axis("off")
165
+
166
  # Save plot to BytesIO
167
  buf = BytesIO()
168
  plt.savefig(buf, format="png", bbox_inches="tight")
169
  plt.close()
170
  buf.seek(0)
171
+
172
+ logger.info("Graph generated successfully.")
173
+
174
  return Image.open(buf)
175
 
176
  # Gradio app function
177
  def app(selected_country, selected_industry):
178
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
179
+ logger.info("Updating CheckboxGroup and filtered data holder.")
180
+
181
+ # Use gr.update() to create an update dictionary for the CheckboxGroup
182
+ return gr.update(
183
  choices=investor_list,
184
  value=investor_list,
185
  visible=True
 
189
  def main():
190
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
191
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
192
+
193
+ logger.info(f"Available countries: {country_list}")
194
+ logger.info(f"Available industries: {industry_list}")
195
+
196
  with gr.Blocks() as demo:
197
  with gr.Row():
198
  country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
199
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
200
+
201
  filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
202
  graph_output = gr.Image(type="pil", label="Venture Network Graph")
203
+
204
  filtered_data_holder = gr.State()
205
+
206
  country_filter.change(
207
  app,
208
  inputs=[country_filter, industry_filter],
 
213
  inputs=[country_filter, industry_filter],
214
  outputs=[filtered_investor_list, filtered_data_holder]
215
  )
216
+
217
  filtered_investor_list.change(
218
  generate_graph,
219
  inputs=[filtered_investor_list, filtered_data_holder],
220
  outputs=graph_output
221
  )
222
+
223
  demo.launch()
224
 
225
  if __name__ == "__main__":