LeonceNsh committed on
Commit
1322835
·
verified ·
1 Parent(s): 9aa537c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -9
app.py CHANGED
@@ -9,24 +9,49 @@ import gradio as gr
9
  file_path = "cbinsights_data.csv" # Replace with your file path
10
  data = pd.read_csv(file_path)
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Clean and prepare data
13
- data.columns = data.columns.str.strip()
14
- data["Valuation_Billions"] = data["Valuation ($B)"].replace({'\$': '', ',': ''}, regex=True)
15
- data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
16
  data = data.applymap(lambda x: x.strip() if isinstance(x, str) else x)
17
 
18
- # Parse the "Select Investors" column to map investors to companies
 
 
 
 
 
 
 
 
 
 
 
19
  def build_investor_company_mapping(df):
20
  mapping = {}
21
  for _, row in df.iterrows():
22
  company = row["Company"]
23
- investors = row["Select Investors"]
24
  if pd.notnull(investors):
25
  for investor in investors.split(","):
26
  investor = investor.strip()
27
  mapping.setdefault(investor, []).append(company)
28
  return mapping
29
 
 
 
30
  # Function to filter investors based on selected country and industry
31
  def filter_investors_by_country_and_industry(selected_country, selected_industry):
32
  filtered_data = data.copy()
@@ -41,7 +66,7 @@ def filter_investors_by_country_and_industry(selected_country, selected_industry
41
  investor_valuations = {}
42
  for investor, companies in investor_company_mapping_filtered.items():
43
  total_valuation = filtered_data[filtered_data["Company"].isin(companies)]["Valuation_Billions"].sum()
44
- if total_valuation >= 20:
45
  investor_valuations[investor] = total_valuation
46
 
47
  return list(investor_valuations.keys()), filtered_data
@@ -65,10 +90,10 @@ def generate_graph(selected_investors, filtered_data):
65
  node_sizes = []
66
  for node in G.nodes:
67
  if node in filtered_mapping:
68
- node_sizes.append(1500)
69
  else:
70
  valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
71
- size = (valuation / max_valuation) * 1500
72
  node_sizes.append(size)
73
 
74
  # Node color: Investors (orange), Companies (green)
@@ -83,7 +108,7 @@ def generate_graph(selected_investors, filtered_data):
83
  node_size=node_sizes,
84
  node_color=node_colors,
85
  font_size=10,
86
- edge_color="#A9A9A9",
87
  alpha=0.9
88
  )
89
 
 
9
  file_path = "cbinsights_data.csv" # Replace with your file path
10
  data = pd.read_csv(file_path)
11
 
12
+ # Standardize column names: strip whitespace and convert to lowercase
13
+ data.columns = data.columns.str.strip().str.lower()
14
+ print("Standardized Column Names:", data.columns.tolist())
15
+
16
+ # Identify the valuation column dynamically
17
+ valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
18
+ if not valuation_columns:
19
+ raise ValueError("No column containing 'Valuation' found in the dataset.")
20
+ elif len(valuation_columns) > 1:
21
+ raise ValueError("Multiple columns containing 'Valuation' found. Please specify.")
22
+ else:
23
+ valuation_column = valuation_columns[0]
24
+
25
  # Clean and prepare data
26
+ data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
27
+ data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
 
28
  data = data.applymap(lambda x: x.strip() if isinstance(x, str) else x)
29
 
30
+ # Rename columns for consistency (optional)
31
+ data = data.rename(columns={
32
+ "company": "Company",
33
+ "valuation_billions": "Valuation_Billions",
34
+ "date_joined": "Date_Joined",
35
+ "country": "Country",
36
+ "city": "City",
37
+ "industry": "Industry",
38
+ "select_investors": "Select_Investors"
39
+ })
40
+
41
+ # Parse the "Select_Investors" column to map investors to companies
42
  def build_investor_company_mapping(df):
43
  mapping = {}
44
  for _, row in df.iterrows():
45
  company = row["Company"]
46
+ investors = row["Select_Investors"]
47
  if pd.notnull(investors):
48
  for investor in investors.split(","):
49
  investor = investor.strip()
50
  mapping.setdefault(investor, []).append(company)
51
  return mapping
52
 
53
+ investor_company_mapping = build_investor_company_mapping(data)
54
+
55
  # Function to filter investors based on selected country and industry
56
  def filter_investors_by_country_and_industry(selected_country, selected_industry):
57
  filtered_data = data.copy()
 
66
  investor_valuations = {}
67
  for investor, companies in investor_company_mapping_filtered.items():
68
  total_valuation = filtered_data[filtered_data["Company"].isin(companies)]["Valuation_Billions"].sum()
69
+ if total_valuation >= 20: # Investors with >= 20B total valuation
70
  investor_valuations[investor] = total_valuation
71
 
72
  return list(investor_valuations.keys()), filtered_data
 
90
  node_sizes = []
91
  for node in G.nodes:
92
  if node in filtered_mapping:
93
+ node_sizes.append(1500) # Fixed size for investors
94
  else:
95
  valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
96
+ size = (valuation / max_valuation) * 1500 if max_valuation else 100
97
  node_sizes.append(size)
98
 
99
  # Node color: Investors (orange), Companies (green)
 
108
  node_size=node_sizes,
109
  node_color=node_colors,
110
  font_size=10,
111
+ edge_color="#A9A9A9", # Light gray edges
112
  alpha=0.9
113
  )
114