LeonceNsh committed on
Commit
e7cb59a
·
verified ·
1 Parent(s): 323aee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -53
app.py CHANGED
@@ -21,11 +21,11 @@ except Exception as e:
21
  logger.error(f"Error loading CSV file: {e}")
22
  raise
23
 
24
- # Standardize column names: strip whitespace and convert to lowercase
25
  data.columns = data.columns.str.strip().str.lower()
26
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
 
28
- # Identify the valuation column dynamically
29
  valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
30
  if not valuation_columns:
31
  logger.error("No column containing 'Valuation' found in the dataset.")
@@ -42,11 +42,10 @@ data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''},
42
  data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
43
  logger.info("Valuation data cleaned and converted to numeric.")
44
 
45
- # Strip whitespace from all string columns
46
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
47
  logger.info("Whitespace stripped from all string columns.")
48
 
49
- # Rename columns for consistency
50
  expected_columns = {
51
  "company": "Company",
52
  "valuation_billions": "Valuation_Billions",
@@ -65,7 +64,7 @@ if missing_columns:
65
  data = data.rename(columns=expected_columns)
66
  logger.info("Columns renamed for consistency.")
67
 
68
- # Parse the "Select_Investors" column to map investors to companies
69
  def build_investor_company_mapping(df):
70
  mapping = {}
71
  for _, row in df.iterrows():
@@ -74,39 +73,31 @@ def build_investor_company_mapping(df):
74
  if pd.notnull(investors):
75
  for investor in investors.split(","):
76
  investor = investor.strip()
77
- if investor: # Ensure investor is not an empty string
78
  mapping.setdefault(investor, []).append(company)
79
  return mapping
80
 
81
  investor_company_mapping = build_investor_company_mapping(data)
82
  logger.info("Investor to company mapping created.")
83
 
84
- # Function to filter investors based on selected country and industry
85
  def filter_investors_by_country_and_industry(selected_country, selected_industry, valuation_threshold):
86
  filtered_data = data.copy()
87
- logger.info(f"Filtering data for Country: {selected_country}, Industry: {selected_industry}")
88
-
89
  if selected_country != "All":
90
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
91
- logger.info(f"Data filtered by country: {selected_country}. Remaining records: {len(filtered_data)}")
92
  if selected_industry != "All":
93
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
94
- logger.info(f"Data filtered by industry: {selected_industry}. Remaining records: {len(filtered_data)}")
95
-
96
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
97
-
98
- # Calculate total valuation per investor
99
  investor_valuations = {}
100
  for investor, companies in investor_company_mapping_filtered.items():
101
  total_valuation = filtered_data[filtered_data["Company"].isin(companies)]["Valuation_Billions"].sum()
102
  if total_valuation >= valuation_threshold:
103
  investor_valuations[investor] = total_valuation
104
-
105
- logger.info(f"Filtered investors with total valuation >= {valuation_threshold}B: {len(investor_valuations)}")
106
-
107
  return list(investor_valuations.keys()), filtered_data
108
 
109
- # Function to generate the Plotly graph
110
  def generate_graph(investor_list, filtered_data):
111
  if not investor_list:
112
  logger.warning("No investors selected. Returning empty figure.")
@@ -118,9 +109,7 @@ def generate_graph(investor_list, filtered_data):
118
  for company in companies:
119
  G.add_edge(investor, company)
120
 
121
- pos = nx.spring_layout(G, k=0.2, seed=42)
122
-
123
- # Create Plotly traces for edges and nodes
124
  edge_trace = go.Scatter(
125
  x=[],
126
  y=[],
@@ -132,8 +121,8 @@ def generate_graph(investor_list, filtered_data):
132
  for edge in G.edges():
133
  x0, y0 = pos[edge[0]]
134
  x1, y1 = pos[edge[1]]
135
- edge_trace['x'] += [x0, x1, None]
136
- edge_trace['y'] += [y0, y1, None]
137
 
138
  node_trace = go.Scatter(
139
  x=[],
@@ -145,58 +134,42 @@ def generate_graph(investor_list, filtered_data):
145
  showscale=True,
146
  colorscale='YlGnBu',
147
  size=10,
148
- colorbar=dict(thickness=15, title='Node Valuation')
149
  )
150
  )
151
 
152
  for node in G.nodes():
153
  x, y = pos[node]
154
- node_trace['x'] += [x]
155
- node_trace['y'] += [y]
156
- node_trace['text'] += [f"{node}"]
157
 
158
  fig = go.Figure(data=[edge_trace, node_trace])
159
  return fig
160
 
161
- # Gradio app function
162
  def app(selected_country, selected_industry, valuation_threshold):
163
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry, valuation_threshold)
164
  graph = generate_graph(investor_list, filtered_data)
165
  return investor_list, graph
166
 
167
- # Gradio Interface
168
  def main():
169
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
170
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
171
 
172
- logger.info(f"Available countries: {country_list}")
173
- logger.info(f"Available industries: {industry_list}")
174
-
175
  with gr.Blocks() as demo:
176
  with gr.Row():
177
  country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
178
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
179
- valuation_threshold = gr.Slider(minimum=0, maximum=50, step=1, value=20, label="Valuation Threshold (in B)")
180
-
181
- investor_output = gr.Text(label="Investor Results")
182
- graph_output = gr.Plot(label="Venture Network Graph")
183
-
184
- country_filter.change(
185
- app,
186
- inputs=[country_filter, industry_filter, valuation_threshold],
187
- outputs=[investor_output, graph_output]
188
- )
189
- industry_filter.change(
190
- app,
191
- inputs=[country_filter, industry_filter, valuation_threshold],
192
- outputs=[investor_output, graph_output]
193
- )
194
- valuation_threshold.change(
195
- app,
196
- inputs=[country_filter, industry_filter, valuation_threshold],
197
- outputs=[investor_output, graph_output]
198
- )
199
-
200
  demo.launch()
201
 
202
  if __name__ == "__main__":
 
21
  logger.error(f"Error loading CSV file: {e}")
22
  raise
23
 
24
+ # Standardize column names
25
  data.columns = data.columns.str.strip().str.lower()
26
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
 
28
+ # Identify the valuation column
29
  valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
30
  if not valuation_columns:
31
  logger.error("No column containing 'Valuation' found in the dataset.")
 
42
  data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
43
  logger.info("Valuation data cleaned and converted to numeric.")
44
 
 
45
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
46
  logger.info("Whitespace stripped from all string columns.")
47
 
48
+ # Rename columns
49
  expected_columns = {
50
  "company": "Company",
51
  "valuation_billions": "Valuation_Billions",
 
64
  data = data.rename(columns=expected_columns)
65
  logger.info("Columns renamed for consistency.")
66
 
67
+ # Build investor to company mapping
68
  def build_investor_company_mapping(df):
69
  mapping = {}
70
  for _, row in df.iterrows():
 
73
  if pd.notnull(investors):
74
  for investor in investors.split(","):
75
  investor = investor.strip()
76
+ if investor:
77
  mapping.setdefault(investor, []).append(company)
78
  return mapping
79
 
80
  investor_company_mapping = build_investor_company_mapping(data)
81
  logger.info("Investor to company mapping created.")
82
 
# Filter investors by country/industry and by total portfolio valuation
def filter_investors_by_country_and_industry(selected_country, selected_industry, valuation_threshold):
    """Select investors whose filtered portfolio meets a valuation threshold.

    Starts from a copy of the module-level ``data`` frame, narrows it to the
    chosen country and industry (the value "All" disables that filter), then
    sums ``Valuation_Billions`` over every company mapped to each investor.

    Args:
        selected_country: Value to match in the "Country" column, or "All".
        selected_industry: Value to match in the "Industry" column, or "All".
        valuation_threshold: Minimum summed valuation an investor must reach
            (inclusive) to be kept.

    Returns:
        A ``(investors, filtered_data)`` tuple: the list of qualifying
        investor names, in the order the mapping produced them, and the
        filtered DataFrame they were computed from.
    """
    subset = data.copy()
    for column, choice in (("Country", selected_country), ("Industry", selected_industry)):
        if choice != "All":
            subset = subset[subset[column] == choice]

    portfolios = build_investor_company_mapping(subset)

    def portfolio_total(companies):
        # Sum the valuation of every remaining row whose company is in the portfolio.
        in_portfolio = subset["Company"].isin(companies)
        return subset.loc[in_portfolio, "Valuation_Billions"].sum()

    totals = {investor: portfolio_total(companies) for investor, companies in portfolios.items()}
    qualifying = [investor for investor, total in totals.items() if total >= valuation_threshold]
    return qualifying, subset
99
 
100
+ # Generate Plotly graph
101
  def generate_graph(investor_list, filtered_data):
102
  if not investor_list:
103
  logger.warning("No investors selected. Returning empty figure.")
 
109
  for company in companies:
110
  G.add_edge(investor, company)
111
 
112
+ pos = nx.spring_layout(G, seed=42)
 
 
113
  edge_trace = go.Scatter(
114
  x=[],
115
  y=[],
 
121
  for edge in G.edges():
122
  x0, y0 = pos[edge[0]]
123
  x1, y1 = pos[edge[1]]
124
+ edge_trace.x += [x0, x1, None]
125
+ edge_trace.y += [y0, y1, None]
126
 
127
  node_trace = go.Scatter(
128
  x=[],
 
134
  showscale=True,
135
  colorscale='YlGnBu',
136
  size=10,
137
+ colorbar=dict(thickness=15, title="Node Value")
138
  )
139
  )
140
 
141
  for node in G.nodes():
142
  x, y = pos[node]
143
+ node_trace.x += [x]
144
+ node_trace.y += [y]
145
+ node_trace.text += [node]
146
 
147
  fig = go.Figure(data=[edge_trace, node_trace])
148
  return fig
149
 
# Gradio callback: recompute the investor list and network figure
def app(selected_country, selected_industry, valuation_threshold):
    """Run the filter and build the graph for the current control values.

    Args:
        selected_country: Country choice passed through to the filter.
        selected_industry: Industry choice passed through to the filter.
        valuation_threshold: Valuation threshold passed through to the filter.

    Returns:
        A ``(investor_list, figure)`` tuple, where ``figure`` is whatever
        ``generate_graph`` returns for the filtered investors and data.
    """
    investors, subset = filter_investors_by_country_and_industry(
        selected_country,
        selected_industry,
        valuation_threshold,
    )
    return investors, generate_graph(investors, subset)
155
 
 
def main():
    """Build the Gradio Blocks interface and launch it.

    The country and industry dropdowns are populated from the distinct
    non-null values of the "Country" and "Industry" columns of the
    module-level ``data`` frame, each prefixed with an "All" option. A change
    to any of the three controls calls ``app`` with all three current values
    and writes its results to the investor text box and the plot.
    """
    # NOTE(review): the nesting of the `with` blocks is reconstructed from a
    # rendering that had lost its indentation - confirm which components
    # belong inside the Row and whether launch() sits inside the Blocks context.
    countries = ["All"] + sorted(data["Country"].dropna().unique())
    industries = ["All"] + sorted(data["Industry"].dropna().unique())

    with gr.Blocks() as demo:
        with gr.Row():
            country_filter = gr.Dropdown(choices=countries, label="Filter by Country", value="All")
            industry_filter = gr.Dropdown(choices=industries, label="Filter by Industry", value="All")
            valuation_slider = gr.Slider(0, 50, value=20, label="Valuation Threshold (B)")

        investor_output = gr.Text(label="Investor List")
        graph_output = gr.Plot(label="Network Graph")

        # Every control triggers the same callback with the same inputs/outputs.
        controls = [country_filter, industry_filter, valuation_slider]
        results = [investor_output, graph_output]
        for control in controls:
            control.change(app, controls, results)

    demo.launch()
174
 
175
  if __name__ == "__main__":