LeonceNsh commited on
Commit
47c5bfd
·
verified ·
1 Parent(s): eedc3a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -51
app.py CHANGED
@@ -25,11 +25,21 @@ except Exception as e:
25
  data.columns = data.columns.str.strip().str.lower()
26
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
 
 
 
 
 
 
 
 
 
 
28
  # Clean and prepare data
 
 
29
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
30
  data.rename(columns={
31
  "company": "Company",
32
- "valuation": "Valuation",
33
  "date_joined": "Date_Joined",
34
  "country": "Country",
35
  "city": "City",
@@ -37,10 +47,6 @@ data.rename(columns={
37
  "select_investors": "Select_Investors"
38
  }, inplace=True)
39
 
40
- # Convert valuation to numeric for proportional node sizing
41
- data["Valuation"] = pd.to_numeric(
42
- data["Valuation"].replace({"\$": "", ",": ""}, regex=True), errors="coerce"
43
- )
44
  logger.info("Data cleaned and columns renamed.")
45
 
46
  # Build investor-company mapping
@@ -66,13 +72,10 @@ def filter_investors(selected_country, selected_industry, selected_investors):
66
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
67
  if selected_industry != "All":
68
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
69
- if selected_investors != ["All"]:
70
- filtered_data = filtered_data[
71
- filtered_data["Select_Investors"].apply(
72
- lambda x: any(inv in x for inv in selected_investors) if pd.notnull(x) else False
73
- )
74
- ]
75
-
76
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
77
  filtered_investors = list(investor_company_mapping_filtered.keys())
78
  return filtered_investors, filtered_data
@@ -83,78 +86,99 @@ def generate_graph(investors, filtered_data):
83
  logger.warning("No investors selected.")
84
  return go.Figure()
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  G = nx.Graph()
87
  for investor in investors:
88
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
89
  for company in companies:
 
 
90
  G.add_edge(investor, company)
91
-
92
  pos = nx.spring_layout(G, seed=42)
93
  edge_x = []
94
  edge_y = []
95
-
96
  for edge in G.edges():
97
  x0, y0 = pos[edge[0]]
98
  x1, y1 = pos[edge[1]]
99
  edge_x.extend([x0, x1, None])
100
  edge_y.extend([y0, y1, None])
101
-
102
  edge_trace = go.Scatter(
103
  x=edge_x,
104
  y=edge_y,
105
- line=dict(width=1, color="#888"),
106
- hoverinfo="none",
107
- mode="lines"
108
  )
109
-
110
  node_x = []
111
  node_y = []
112
  node_text = []
113
  node_color = []
114
  node_size = []
115
-
116
- # Color palette for investors (color blind friendly)
117
- investor_colors = [
118
- "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"
119
- ]
120
-
121
- investor_color_map = {investor: investor_colors[i % len(investor_colors)] for i, investor in enumerate(investors)}
122
-
123
  for node in G.nodes():
124
  x, y = pos[node]
125
  node_x.append(x)
126
  node_y.append(y)
127
-
128
  if node in investors:
 
129
  node_text.append(node) # Label investors
130
- node_color.append(investor_color_map[node]) # Assign distinct colors to investors
131
- node_size.append(20) # Fixed size for investor nodes
132
  else:
133
- valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation"].sum()
 
 
 
 
 
 
 
 
134
  node_text.append("") # Hide company labels by default
135
- node_color.append("lightgreen") # Light green for companies
136
- node_size.append(max(10, valuation / 100)) # Size proportional to valuation
137
-
138
  node_trace = go.Scatter(
139
  x=node_x,
140
  y=node_y,
141
  text=node_text,
142
- mode="markers",
143
- hoverinfo="text",
144
  marker=dict(
145
  showscale=False,
146
  size=node_size,
147
  color=node_color,
148
  )
149
  )
150
-
151
  fig = go.Figure(data=[edge_trace, node_trace])
152
  fig.update_layout(
153
  showlegend=False,
154
  title="Venture Networks",
155
  titlefont_size=20,
156
  margin=dict(l=20, r=20, t=50, b=20),
157
- hovermode="closest",
158
  width=1200,
159
  height=800
160
  )
@@ -164,34 +188,31 @@ def generate_graph(investors, filtered_data):
164
  def app(selected_country, selected_industry, selected_investors):
165
  investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
166
  graph = generate_graph(investors, filtered_data)
167
- return investors, graph
168
 
169
  # Main function
170
  def main():
171
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
172
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
173
- investor_list = ["All"] + sorted(investor_company_mapping.keys())
174
 
175
  with gr.Blocks() as demo:
176
  with gr.Row():
177
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
178
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
179
- investor_filter = gr.CheckboxGroup(choices=investor_list, label="Investors", value=["All"])
180
 
181
  investor_output = gr.Textbox(label="Filtered Investors")
182
  graph_output = gr.Plot(label="Network Graph")
183
 
184
- country_filter.change(
185
- app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
186
- )
187
- industry_filter.change(
188
- app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
189
- )
190
- investor_filter.change(
191
- app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
192
- )
193
 
194
- demo.launch()
195
 
196
  if __name__ == "__main__":
197
  main()
 
25
  data.columns = data.columns.str.strip().str.lower()
26
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
 
28
+ # Identify the valuation column
29
+ valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
30
+ if len(valuation_columns) != 1:
31
+ logger.error("Unable to identify a single valuation column.")
32
+ raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
33
+
34
+ valuation_column = valuation_columns[0]
35
+ logger.info(f"Using valuation column: {valuation_column}")
36
+
37
  # Clean and prepare data
38
+ data["Valuation_Billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
39
+ data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
40
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
41
  data.rename(columns={
42
  "company": "Company",
 
43
  "date_joined": "Date_Joined",
44
  "country": "Country",
45
  "city": "City",
 
47
  "select_investors": "Select_Investors"
48
  }, inplace=True)
49
 
 
 
 
 
50
  logger.info("Data cleaned and columns renamed.")
51
 
52
  # Build investor-company mapping
 
72
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
73
  if selected_industry != "All":
74
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
75
+ if selected_investors:
76
+ pattern = '|'.join(selected_investors)
77
+ filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
78
+
 
 
 
79
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
80
  filtered_investors = list(investor_company_mapping_filtered.keys())
81
  return filtered_investors, filtered_data
 
86
  logger.warning("No investors selected.")
87
  return go.Figure()
88
 
89
+ # Create a color map for investors
90
+ unique_investors = investors
91
+ num_colors = len(unique_investors)
92
+ color_palette = [
93
+ "#000000", # black
94
+ "#E69F00", # orange
95
+ "#56B4E9", # sky blue
96
+ "#009E73", # bluish green
97
+ "#F0E442", # yellow
98
+ "#0072B2", # blue
99
+ "#D55E00", # vermillion
100
+ "#CC79A7", # reddish purple
101
+ ]
102
+ # Extend color_palette if necessary
103
+ while num_colors > len(color_palette):
104
+ color_palette.extend(color_palette)
105
+
106
+ investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
107
+
108
  G = nx.Graph()
109
  for investor in investors:
110
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
111
  for company in companies:
112
+ G.add_node(company)
113
+ G.add_node(investor)
114
  G.add_edge(investor, company)
115
+
116
  pos = nx.spring_layout(G, seed=42)
117
  edge_x = []
118
  edge_y = []
119
+
120
  for edge in G.edges():
121
  x0, y0 = pos[edge[0]]
122
  x1, y1 = pos[edge[1]]
123
  edge_x.extend([x0, x1, None])
124
  edge_y.extend([y0, y1, None])
125
+
126
  edge_trace = go.Scatter(
127
  x=edge_x,
128
  y=edge_y,
129
+ line=dict(width=1, color='#888'),
130
+ hoverinfo='none',
131
+ mode='lines'
132
  )
133
+
134
  node_x = []
135
  node_y = []
136
  node_text = []
137
  node_color = []
138
  node_size = []
139
+
 
 
 
 
 
 
 
140
  for node in G.nodes():
141
  x, y = pos[node]
142
  node_x.append(x)
143
  node_y.append(y)
 
144
  if node in investors:
145
+ # Investor node
146
  node_text.append(node) # Label investors
147
+ node_color.append(investor_color_map[node]) # Color assigned to investor
148
+ node_size.append(20) # Fixed size for investors
149
  else:
150
+ # Company node
151
+ valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].values
152
+ if len(valuation) > 0 and not pd.isnull(valuation[0]):
153
+ size = valuation[0] * 5 # Scale size as needed
154
+ if size < 5:
155
+ size = 5 # Minimum size
156
+ else:
157
+ size = 10 # Default size
158
+ node_size.append(size)
159
  node_text.append("") # Hide company labels by default
160
+ node_color.append("#b2df8a") # Light green color for companies
161
+
 
162
  node_trace = go.Scatter(
163
  x=node_x,
164
  y=node_y,
165
  text=node_text,
166
+ mode='markers',
167
+ hoverinfo='text',
168
  marker=dict(
169
  showscale=False,
170
  size=node_size,
171
  color=node_color,
172
  )
173
  )
174
+
175
  fig = go.Figure(data=[edge_trace, node_trace])
176
  fig.update_layout(
177
  showlegend=False,
178
  title="Venture Networks",
179
  titlefont_size=20,
180
  margin=dict(l=20, r=20, t=50, b=20),
181
+ hovermode='closest',
182
  width=1200,
183
  height=800
184
  )
 
188
  def app(selected_country, selected_industry, selected_investors):
189
  investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
190
  graph = generate_graph(investors, filtered_data)
191
+ return ', '.join(investors), graph
192
 
193
  # Main function
194
  def main():
195
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
196
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
197
+ investor_list = sorted(investor_company_mapping.keys())
198
 
199
  with gr.Blocks() as demo:
200
  with gr.Row():
201
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
202
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
203
+ investor_filter = gr.Dropdown(choices=investor_list, label="Investor", value=[], multiselect=True)
204
 
205
  investor_output = gr.Textbox(label="Filtered Investors")
206
  graph_output = gr.Plot(label="Network Graph")
207
 
208
+ inputs = [country_filter, industry_filter, investor_filter]
209
+ outputs = [investor_output, graph_output]
210
+
211
+ country_filter.change(app, inputs, outputs)
212
+ industry_filter.change(app, inputs, outputs)
213
+ investor_filter.change(app, inputs, outputs)
 
 
 
214
 
215
+ demo.launch()
216
 
217
  if __name__ == "__main__":
218
  main()