LeonceNsh committed on
Commit
9aa537c
·
verified ·
1 Parent(s): 09b019d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -76
app.py CHANGED
@@ -9,124 +9,107 @@ import gradio as gr
9
  file_path = "cbinsights_data.csv" # Replace with your file path
10
  data = pd.read_csv(file_path)
11
 
12
- # Rename columns based on the first row and drop the header row
13
- data.columns = data.iloc[0]
14
- data = data[1:]
15
- data.columns = ["Company", "Valuation_Billions", "Date_Joined", "Country", "City", "Industry", "Select_Investors"]
16
-
17
  # Clean and prepare data
18
- data["Valuation_Billions"] = data["Valuation_Billions"].str.replace('$', '').str.split('.').str[0]
 
19
  data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
20
  data = data.applymap(lambda x: x.strip() if isinstance(x, str) else x)
21
 
22
- # Parse the "Select_Investors" column to map investors to companies
23
- investor_company_mapping = {}
24
- for _, row in data.iterrows():
25
- company = row["Company"]
26
- investors = row["Select_Investors"]
27
- if pd.notnull(investors):
28
- for investor in investors.split(","):
29
- investor = investor.strip()
30
- if investor not in investor_company_mapping:
31
- investor_company_mapping[investor] = []
32
- investor_company_mapping[investor].append(company)
33
-
34
- # Gradio app functions
35
  def filter_investors_by_country_and_industry(selected_country, selected_industry):
36
- filtered_data = data
37
-
38
- # Apply filters
39
  if selected_country != "All":
40
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
41
  if selected_industry != "All":
42
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
43
-
 
 
44
  # Calculate total valuation per investor
45
  investor_valuations = {}
46
- for investor, companies in investor_company_mapping.items():
47
- total_valuation = 0
48
- for company in companies:
49
- if company in filtered_data["Company"].values:
50
- valuation = filtered_data.loc[filtered_data["Company"] == company, "Valuation_Billions"].values
51
- total_valuation += valuation[0] if len(valuation) > 0 else 0
52
- if total_valuation >= 20: # Filter by total valuation
53
  investor_valuations[investor] = total_valuation
54
-
55
  return list(investor_valuations.keys()), filtered_data
56
 
 
57
  def generate_graph(selected_investors, filtered_data):
58
  if not selected_investors:
59
  return None
60
 
61
- # Filter the investor-to-company mapping
62
- filtered_mapping = {}
63
- for investor, companies in investor_company_mapping.items():
64
- if investor in selected_investors:
65
- filtered_companies = [c for c in companies if c in filtered_data["Company"].values]
66
- if filtered_companies:
67
- filtered_mapping[investor] = filtered_companies
68
 
69
- # Use the filtered mapping to build the graph
70
  G = nx.Graph()
71
  for investor, companies in filtered_mapping.items():
72
  for company in companies:
73
  G.add_edge(investor, company)
74
 
75
- # Node sizes based on valuation
 
76
  node_sizes = []
77
- for node in G.nodes:
78
- if node in filtered_mapping: # Fixed size for investors
79
- node_sizes.append(2000)
80
- else: # Company size based on valuation
81
- valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
82
- node_sizes.append(valuation[0] * 50 if len(valuation) > 0 else 100)
83
-
84
- # Node colors
85
- node_colors = []
86
  for node in G.nodes:
87
  if node in filtered_mapping:
88
- node_colors.append("#FF5733") # Distinct color for investors
89
  else:
90
- node_colors.append("#33FF57") # Distinct color for companies
 
 
 
 
 
91
 
92
- # Create the graph plot
93
- plt.figure(figsize=(18, 18))
94
- pos = nx.spring_layout(G, k=0.2, seed=42) # Fixed seed for consistent layout
95
  nx.draw(
96
  G, pos,
97
  with_labels=True,
98
  node_size=node_sizes,
99
  node_color=node_colors,
100
- alpha=0.8, # Slight transparency for Tufte-inspired visuals
101
  font_size=10,
102
- font_weight="bold",
103
- edge_color="#B0BEC5", # Neutral, muted edge color
104
- width=0.8 # Thin edges for minimal visual clutter
105
  )
106
 
107
- # Add a legend for node size (valuation)
108
- min_size, max_size = 50, 5000 # Example scale
109
- for size, label in zip([min_size, max_size], ["$1B", "$100B"]):
110
- plt.scatter([], [], s=size, color="#33FF57", label=f"{label} valuation")
111
- plt.legend(scatterpoints=1, frameon=False, labelspacing=1.5, loc="lower left", fontsize=12)
 
 
112
 
113
- plt.title("Venture Funded Companies Visualization", fontsize=20)
114
- plt.axis('off')
115
 
116
- # Save plot to BytesIO object
117
  buf = BytesIO()
118
  plt.savefig(buf, format="png", bbox_inches="tight")
119
  plt.close()
120
  buf.seek(0)
121
 
122
- # Convert BytesIO to PIL image
123
- image = Image.open(buf)
124
- return image
125
 
 
126
  def app(selected_country, selected_industry):
127
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
128
-
129
- return gr.update(
130
  choices=investor_list,
131
  value=investor_list,
132
  visible=True
@@ -142,11 +125,7 @@ def main():
142
  country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
143
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
144
 
145
- filtered_investor_list = gr.CheckboxGroup(
146
- choices=[],
147
- label="Select Investors",
148
- visible=False
149
- )
150
  graph_output = gr.Image(type="pil", label="Venture Network Graph")
151
 
152
  filtered_data_holder = gr.State()
 
9
  file_path = "cbinsights_data.csv" # Replace with your file path
10
  data = pd.read_csv(file_path)
11
 
 
 
 
 
 
12
  # Clean and prepare data
13
# Header names are trimmed first so the column lookups below are not broken
# by stray whitespace in the CSV header row.
data.columns = data.columns.str.strip()

# Trim surrounding whitespace from every string cell *before* numeric parsing,
# so padded values such as " $1.5 " are cleaned prior to coercion.
# apply + Series.map is used instead of DataFrame.applymap, which is
# deprecated in recent pandas releases.
data = data.apply(
    lambda column: column.map(lambda x: x.strip() if isinstance(x, str) else x)
)

# Raw string: "\$" must reach the regex engine as an escaped dollar sign.
# In a normal string literal "\$" is an invalid escape sequence and triggers
# a warning on modern Python versions.
data["Valuation_Billions"] = data["Valuation ($B)"].replace(
    {r"\$": "", ",": ""}, regex=True
)
# Anything that still cannot be parsed becomes NaN rather than raising.
data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors="coerce")
17
 
18
+ # Parse the "Select Investors" column to map investors to companies
19
def build_investor_company_mapping(df):
    """Map each investor name to the companies it is listed against.

    Reads the "Company" and "Select Investors" columns of *df*. The investors
    cell is treated as a comma-separated list of names.

    Args:
        df: DataFrame with "Company" and "Select Investors" columns.

    Returns:
        A plain dict ``{investor_name: [company, ...]}``. Rows whose investors
        cell is null contribute nothing. Blank names (e.g. from a trailing or
        doubled comma) are skipped, and an investor repeated within one row is
        recorded once for that row.
    """
    mapping = {}
    for company, investors in zip(df["Company"], df["Select Investors"]):
        if pd.isnull(investors):
            continue
        # str() guards against non-string cells (e.g. a purely numeric name);
        # dict.fromkeys de-duplicates within the row while keeping order.
        names = dict.fromkeys(name.strip() for name in str(investors).split(","))
        for investor in names:
            if not investor:
                continue
            mapping.setdefault(investor, []).append(company)
    return mapping
29
+
30
# Function to filter investors based on selected country and industry
def filter_investors_by_country_and_industry(selected_country, selected_industry,
                                             min_total_valuation=20):
    """Filter the dataset and list investors whose holdings meet a threshold.

    Args:
        selected_country: Value to match in the "Country" column, or "All" to
            skip the country filter.
        selected_industry: Value to match in the "Industry" column, or "All"
            to skip the industry filter.
        min_total_valuation: Minimum summed "Valuation_Billions" an investor
            must reach (within the filtered rows) to be returned. Defaults to
            20, the previously hard-coded cut-off.

    Returns:
        A tuple ``(investor_names, filtered_data)`` where *investor_names* is
        a list of qualifying investors and *filtered_data* is the filtered
        DataFrame.
    """
    # Work on a copy so the module-level frame is never handed out directly.
    filtered_data = data.copy()
    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)

    # Sum valuations per company once, instead of re-scanning the whole frame
    # with isin() for every investor. groupby().sum() skips NaN valuations.
    valuation_by_company = (
        filtered_data.groupby("Company")["Valuation_Billions"].sum().to_dict()
    )

    # Calculate total valuation per investor
    investor_valuations = {}
    for investor, companies in investor_company_mapping_filtered.items():
        # dict.fromkeys: count each company name once, in a stable order.
        total_valuation = sum(
            valuation_by_company.get(company, 0.0)
            for company in dict.fromkeys(companies)
        )
        if total_valuation >= min_total_valuation:
            investor_valuations[investor] = total_valuation

    return list(investor_valuations.keys()), filtered_data
48
 
49
# Function to generate the graph
def generate_graph(selected_investors, filtered_data):
    """Render the investor-company network as an image.

    Args:
        selected_investors: Investor names to include in the graph.
        filtered_data: DataFrame holding at least the "Company",
            "Select Investors" and "Valuation_Billions" columns.

    Returns:
        The image opened from the rendered PNG, or ``None`` when there is
        nothing to draw: no investors selected, no data supplied, or none of
        the selected investors appears in *filtered_data*.
    """
    if not selected_investors or filtered_data is None:
        return None

    from matplotlib.lines import Line2D

    investor_color = "#FF8C00"
    company_color = "#32CD32"
    investor_node_size = 1500
    max_company_node_size = 1500
    # Floor so companies with a missing/zero valuation are still visible.
    min_company_node_size = 50

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    # A selection can include investors that are not present in the current
    # filtered rows; skip them instead of raising KeyError.
    filtered_mapping = {
        inv: investor_company_mapping_filtered[inv]
        for inv in selected_investors
        if inv in investor_company_mapping_filtered
    }
    if not filtered_mapping:
        return None

    # Build the graph
    G = nx.Graph()
    for investor, companies in filtered_mapping.items():
        for company in companies:
            G.add_edge(investor, company)

    # Node size based on valuation. Per-company totals are computed once
    # rather than scanning the frame for every node.
    valuation_by_company = (
        filtered_data.groupby("Company")["Valuation_Billions"].sum().to_dict()
    )
    max_valuation = filtered_data["Valuation_Billions"].max()
    # False for NaN (empty / all-missing column) and for zero, both of which
    # would otherwise produce NaN or infinite sizes in the division below.
    can_scale = bool(max_valuation > 0)

    node_sizes = []
    for node in G.nodes:
        if node in filtered_mapping:
            node_sizes.append(investor_node_size)
            continue
        valuation = valuation_by_company.get(node, 0.0)
        scaled = (valuation / max_valuation) * max_company_node_size if can_scale else 0.0
        node_sizes.append(max(scaled, min_company_node_size))

    # Node color: Investors (orange), Companies (green)
    node_colors = [
        investor_color if node in filtered_mapping else company_color
        for node in G.nodes
    ]

    # Draw the graph; the figure is closed even if drawing or saving fails.
    fig = plt.figure(figsize=(15, 15))
    try:
        pos = nx.spring_layout(G, k=0.2, seed=42)
        nx.draw(
            G, pos,
            with_labels=True,
            node_size=node_sizes,
            node_color=node_colors,
            font_size=10,
            edge_color="#A9A9A9",
            alpha=0.9,
        )

        # Legend
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', label='Investor',
                   markersize=10, markerfacecolor=investor_color),
            Line2D([0], [0], marker='o', color='w', label='Company',
                   markersize=10, markerfacecolor=company_color),
        ]
        plt.legend(handles=legend_elements, loc='upper left')

        plt.title("Venture Network Visualization", fontsize=20)
        plt.axis("off")

        # Save plot to BytesIO
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)

    return Image.open(buf)
 
 
108
 
109
+ # Gradio app function
110
  def app(selected_country, selected_industry):
111
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
112
+ return gr.CheckboxGroup.update(
 
113
  choices=investor_list,
114
  value=investor_list,
115
  visible=True
 
125
  country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
126
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
127
 
128
+ filtered_investor_list = gr.CheckboxGroup(choices=[], label="Select Investors", visible=False)
 
 
 
 
129
  graph_output = gr.Image(type="pil", label="Venture Network Graph")
130
 
131
  filtered_data_holder = gr.State()