LeonceNsh committed on
Commit
9e2bc99
·
verified ·
1 Parent(s): 2a73beb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -57
app.py CHANGED
@@ -4,6 +4,7 @@ import plotly.graph_objects as go
4
  import gradio as gr
5
  import re
6
  import logging
 
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO)
@@ -12,43 +13,48 @@ logger = logging.getLogger(__name__)
12
  # Load and preprocess the dataset
13
  file_path = "cbinsights_data.csv" # Replace with your actual file path
14
 
15
- try:
16
- data = pd.read_csv(file_path, skiprows=1)
17
- logger.info("CSV file loaded successfully.")
18
- except FileNotFoundError:
19
- logger.error(f"File not found: {file_path}")
20
- raise
21
- except Exception as e:
22
- logger.error(f"Error loading CSV file: {e}")
23
- raise
24
-
25
- # Standardize column names
26
- data.columns = data.columns.str.strip().str.lower()
27
- logger.info(f"Standardized Column Names: {data.columns.tolist()}")
28
-
29
- # Identify the valuation column
30
- valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
31
- if len(valuation_columns) != 1:
32
- logger.error("Unable to identify a single valuation column.")
33
- raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
34
-
35
- valuation_column = valuation_columns[0]
36
- logger.info(f"Using valuation column: {valuation_column}")
37
-
38
- # Clean and prepare data
39
- data["Valuation_Billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
40
- data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce')
41
- data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
42
- data.rename(columns={
43
- "company": "Company",
44
- "date_joined": "Date_Joined",
45
- "country": "Country",
46
- "city": "City",
47
- "industry": "Industry",
48
- "select_investors": "Select_Investors"
49
- }, inplace=True)
50
-
51
- logger.info("Data cleaned and columns renamed.")
 
 
 
 
 
52
 
53
  # Build investor-company mapping
54
  def build_investor_company_mapping(df):
@@ -73,11 +79,11 @@ def filter_investors(selected_country, selected_industry, selected_investors, se
73
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
74
  if selected_industry != "All":
75
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
 
 
76
  if selected_investors:
77
  pattern = '|'.join([re.escape(inv) for inv in selected_investors])
78
  filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
79
- if selected_company != "All":
80
- filtered_data = filtered_data[filtered_data["Company"] == selected_company]
81
 
82
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
83
  filtered_investors = list(investor_company_mapping_filtered.keys())
@@ -105,9 +111,9 @@ def generate_graph(investors, filtered_data):
105
  ]
106
  while num_colors > len(color_palette):
107
  color_palette.extend(color_palette)
108
-
109
  investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
110
-
111
  G = nx.Graph()
112
  for investor in investors:
113
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist()
@@ -115,17 +121,17 @@ def generate_graph(investors, filtered_data):
115
  G.add_node(company)
116
  G.add_node(investor)
117
  G.add_edge(investor, company)
118
-
119
  pos = nx.spring_layout(G, seed=42)
120
  edge_x = []
121
  edge_y = []
122
-
123
  for edge in G.edges():
124
  x0, y0 = pos[edge[0]]
125
  x1, y1 = pos[edge[1]]
126
  edge_x.extend([x0, x1, None])
127
  edge_y.extend([y0, y1, None])
128
-
129
  edge_trace = go.Scatter(
130
  x=edge_x,
131
  y=edge_y,
@@ -133,14 +139,14 @@ def generate_graph(investors, filtered_data):
133
  hoverinfo='none',
134
  mode='lines'
135
  )
136
-
137
  node_x = []
138
  node_y = []
139
  node_text = []
140
  node_color = []
141
  node_size = []
142
  node_hovertext = []
143
-
144
  for node in G.nodes():
145
  x, y = pos[node]
146
  node_x.append(x)
@@ -168,7 +174,7 @@ def generate_graph(investors, filtered_data):
168
  if len(valuation) > 0 and not pd.isnull(valuation[0]):
169
  hovertext += f"<br>Valuation: ${valuation[0]}B"
170
  node_hovertext.append(hovertext)
171
-
172
  node_trace = go.Scatter(
173
  x=node_x,
174
  y=node_y,
@@ -185,7 +191,7 @@ def generate_graph(investors, filtered_data):
185
  textposition="middle center",
186
  textfont=dict(size=12, color="#000000")
187
  )
188
-
189
  legend_items = []
190
  for investor in unique_investors:
191
  legend_items.append(
@@ -202,7 +208,7 @@ def generate_graph(investors, filtered_data):
202
  name=investor
203
  )
204
  )
205
-
206
  fig = go.Figure(data=legend_items + [edge_trace, node_trace])
207
  fig.update_layout(
208
  title="Venture Networks",
@@ -212,13 +218,13 @@ def generate_graph(investors, filtered_data):
212
  width=1200,
213
  height=800
214
  )
215
-
216
  fig.update_layout(
217
  autosize=True,
218
  xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
219
  yaxis={'showgrid': False, 'zeroline': False, 'visible': False}
220
  )
221
-
222
  return fig
223
 
224
  # Gradio app
@@ -235,7 +241,7 @@ def main():
235
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
236
  company_list = ["All"] + sorted(data["Company"].dropna().unique())
237
  investor_list = sorted(investor_company_mapping.keys())
238
-
239
  with gr.Blocks(title="Venture Networks Visualization") as demo:
240
  gr.Markdown("""
241
  # Venture Networks Visualization
@@ -249,19 +255,29 @@ def main():
249
  with gr.Row():
250
  investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
251
  graph_output = gr.Plot(label="Network Graph")
252
-
253
  inputs = [country_filter, industry_filter, company_filter, investor_filter]
254
  outputs = [investor_output, graph_output]
255
-
 
256
  country_filter.change(app, inputs, outputs)
257
  industry_filter.change(app, inputs, outputs)
258
  company_filter.change(app, inputs, outputs)
259
  investor_filter.change(app, inputs, outputs)
260
-
261
  gr.Markdown("""
262
- **Instructions:** Use the dropdowns to filter the network graph.
 
 
 
 
 
 
 
 
 
263
  """)
264
-
265
  demo.launch()
266
 
267
  if __name__ == "__main__":
 
4
  import gradio as gr
5
  import re
6
  import logging
7
+ import os
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO)
 
13
  # Load and preprocess the dataset
14
  file_path = "cbinsights_data.csv" # Replace with your actual file path
15
 
16
def load_data():
    """Load the CSV at the module-level ``file_path`` and return a cleaned DataFrame.

    Steps, in order:

    1. Read the file, skipping its first line (``skiprows=1``), so the second
       line of the file is used as the header row.
    2. Strip and lower-case every column name.
    3. Locate the single column whose name contains ``"valuation"`` and derive
       a numeric ``Valuation_Billions`` column from it by removing ``$`` and
       ``,`` characters; values that still fail to parse become NaN.
    4. Strip surrounding whitespace from every text column.
    5. Rename the known lower-case columns to the capitalised names used by the
       rest of the app (``Company``, ``Country``, ``Industry``, ...).

    Returns:
        pandas.DataFrame: the cleaned dataset.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        ValueError: if the number of columns containing ``"valuation"`` is not
            exactly one.
        Exception: any error raised by ``pandas.read_csv`` is logged and
            re-raised unchanged.
    """
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        raise FileNotFoundError(f"File not found: {file_path}")

    # Keep the try body limited to the one call that is expected to fail.
    try:
        data = pd.read_csv(file_path, skiprows=1)
    except Exception as e:
        logger.error(f"Error loading CSV file: {e}")
        raise
    logger.info("CSV file loaded successfully.")

    # Standardize column names
    data.columns = data.columns.str.strip().str.lower()
    logger.info(f"Standardized Column Names: {data.columns.tolist()}")

    # Identify the valuation column (names are already lower-case at this point)
    valuation_columns = [col for col in data.columns if 'valuation' in col]
    if len(valuation_columns) != 1:
        logger.error("Unable to identify a single valuation column.")
        raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")

    valuation_column = valuation_columns[0]
    logger.info(f"Using valuation column: {valuation_column}")

    # Clean and prepare data.
    # The pattern is a raw string: a plain '\$' is an invalid escape sequence
    # and triggers a warning on current Python versions.
    stripped_valuation = data[valuation_column].replace({r'\$': '', ',': ''}, regex=True)
    data["Valuation_Billions"] = pd.to_numeric(stripped_valuation, errors='coerce')

    # Strip whitespace from text columns. Besides the classic "object" dtype,
    # also accept pandas' dedicated string dtypes, for which the
    # `dtype == "object"` comparison is False and stripping would be skipped.
    data = data.apply(
        lambda col: col.str.strip()
        if col.dtype == "object" or pd.api.types.is_string_dtype(col.dtype)
        else col
    )

    data = data.rename(columns={
        "company": "Company",
        "date_joined": "Date_Joined",
        "country": "Country",
        "city": "City",
        "industry": "Industry",
        "select_investors": "Select_Investors",
    })

    logger.info("Data cleaned and columns renamed.")
    return data
56
+
57
+ data = load_data()
58
 
59
  # Build investor-company mapping
60
  def build_investor_company_mapping(df):
 
79
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
80
  if selected_industry != "All":
81
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
82
+ if selected_company != "All":
83
+ filtered_data = filtered_data[filtered_data["Company"] == selected_company]
84
  if selected_investors:
85
  pattern = '|'.join([re.escape(inv) for inv in selected_investors])
86
  filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(pattern, na=False)]
 
 
87
 
88
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
89
  filtered_investors = list(investor_company_mapping_filtered.keys())
 
111
  ]
112
  while num_colors > len(color_palette):
113
  color_palette.extend(color_palette)
114
+
115
  investor_color_map = {investor: color_palette[i] for i, investor in enumerate(unique_investors)}
116
+
117
  G = nx.Graph()
118
  for investor in investors:
119
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(re.escape(investor), na=False)]["Company"].tolist()
 
121
  G.add_node(company)
122
  G.add_node(investor)
123
  G.add_edge(investor, company)
124
+
125
  pos = nx.spring_layout(G, seed=42)
126
  edge_x = []
127
  edge_y = []
128
+
129
  for edge in G.edges():
130
  x0, y0 = pos[edge[0]]
131
  x1, y1 = pos[edge[1]]
132
  edge_x.extend([x0, x1, None])
133
  edge_y.extend([y0, y1, None])
134
+
135
  edge_trace = go.Scatter(
136
  x=edge_x,
137
  y=edge_y,
 
139
  hoverinfo='none',
140
  mode='lines'
141
  )
142
+
143
  node_x = []
144
  node_y = []
145
  node_text = []
146
  node_color = []
147
  node_size = []
148
  node_hovertext = []
149
+
150
  for node in G.nodes():
151
  x, y = pos[node]
152
  node_x.append(x)
 
174
  if len(valuation) > 0 and not pd.isnull(valuation[0]):
175
  hovertext += f"<br>Valuation: ${valuation[0]}B"
176
  node_hovertext.append(hovertext)
177
+
178
  node_trace = go.Scatter(
179
  x=node_x,
180
  y=node_y,
 
191
  textposition="middle center",
192
  textfont=dict(size=12, color="#000000")
193
  )
194
+
195
  legend_items = []
196
  for investor in unique_investors:
197
  legend_items.append(
 
208
  name=investor
209
  )
210
  )
211
+
212
  fig = go.Figure(data=legend_items + [edge_trace, node_trace])
213
  fig.update_layout(
214
  title="Venture Networks",
 
218
  width=1200,
219
  height=800
220
  )
221
+
222
  fig.update_layout(
223
  autosize=True,
224
  xaxis={'showgrid': False, 'zeroline': False, 'visible': False},
225
  yaxis={'showgrid': False, 'zeroline': False, 'visible': False}
226
  )
227
+
228
  return fig
229
 
230
  # Gradio app
 
241
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
242
  company_list = ["All"] + sorted(data["Company"].dropna().unique())
243
  investor_list = sorted(investor_company_mapping.keys())
244
+
245
  with gr.Blocks(title="Venture Networks Visualization") as demo:
246
  gr.Markdown("""
247
  # Venture Networks Visualization
 
255
  with gr.Row():
256
  investor_output = gr.Textbox(label="Filtered Investors", interactive=False)
257
  graph_output = gr.Plot(label="Network Graph")
258
+
259
  inputs = [country_filter, industry_filter, company_filter, investor_filter]
260
  outputs = [investor_output, graph_output]
261
+
262
+ # Update the graph when any filter changes
263
  country_filter.change(app, inputs, outputs)
264
  industry_filter.change(app, inputs, outputs)
265
  company_filter.change(app, inputs, outputs)
266
  investor_filter.change(app, inputs, outputs)
267
+
268
  gr.Markdown("""
269
+ **Instructions:**
270
+ - **Country**: Filter companies by country.
271
+ - **Industry**: Filter companies by industry.
272
+ - **Company**: Select a specific company to focus on.
273
+ - **Select Investors**: Choose investors to visualize their network connections.
274
+
275
+ **Tips:**
276
+ - Hover over nodes to see more information.
277
+ - Use the legend to identify investor nodes.
278
+ - Adjust filters to refine your network view.
279
  """)
280
+
281
  demo.launch()
282
 
283
  if __name__ == "__main__":