LeonceNsh committed on
Commit
cf8a69c
·
verified ·
1 Parent(s): e7cb59a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -62
app.py CHANGED
@@ -27,26 +27,18 @@ logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
 
28
  # Identify the valuation column
29
  valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
30
- if not valuation_columns:
31
- logger.error("No column containing 'Valuation' found in the dataset.")
32
- raise ValueError("Data Error: Unable to find the valuation column. Please check your CSV file.")
33
- elif len(valuation_columns) > 1:
34
- logger.error("Multiple columns containing 'Valuation' found in the dataset.")
35
- raise ValueError("Data Error: Multiple valuation columns detected. Please ensure only one valuation column exists.")
36
- else:
37
- valuation_column = valuation_columns[0]
38
- logger.info(f"Using valuation column: {valuation_column}")
39
 
40
  # Clean and prepare data
41
  data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
42
  data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
43
- logger.info("Valuation data cleaned and converted to numeric.")
44
-
45
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
46
- logger.info("Whitespace stripped from all string columns.")
47
-
48
- # Rename columns
49
- expected_columns = {
50
  "company": "Company",
51
  "valuation_billions": "Valuation_Billions",
52
  "date_joined": "Date_Joined",
@@ -54,17 +46,11 @@ expected_columns = {
54
  "city": "City",
55
  "industry": "Industry",
56
  "select_investors": "Select_Investors"
57
- }
58
 
59
- missing_columns = set(expected_columns.keys()) - set(data.columns)
60
- if missing_columns:
61
- logger.error(f"Missing columns in the dataset: {missing_columns}")
62
- raise ValueError(f"Data Error: Missing columns {missing_columns} in the dataset.")
63
 
64
- data = data.rename(columns=expected_columns)
65
- logger.info("Columns renamed for consistency.")
66
-
67
- # Build investor to company mapping
68
  def build_investor_company_mapping(df):
69
  mapping = {}
70
  for _, row in df.iterrows():
@@ -80,8 +66,8 @@ def build_investor_company_mapping(df):
80
  investor_company_mapping = build_investor_company_mapping(data)
81
  logger.info("Investor to company mapping created.")
82
 
83
- # Filter investors
84
- def filter_investors_by_country_and_industry(selected_country, selected_industry, valuation_threshold):
85
  filtered_data = data.copy()
86
  if selected_country != "All":
87
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
@@ -89,81 +75,106 @@ def filter_investors_by_country_and_industry(selected_country, selected_industry
89
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
90
 
91
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
92
- investor_valuations = {}
93
- for investor, companies in investor_company_mapping_filtered.items():
94
- total_valuation = filtered_data[filtered_data["Company"].isin(companies)]["Valuation_Billions"].sum()
95
- if total_valuation >= valuation_threshold:
96
- investor_valuations[investor] = total_valuation
97
-
98
- return list(investor_valuations.keys()), filtered_data
99
 
100
  # Generate Plotly graph
101
- def generate_graph(investor_list, filtered_data):
102
- if not investor_list:
103
- logger.warning("No investors selected. Returning empty figure.")
104
  return go.Figure()
105
 
106
  G = nx.Graph()
107
- for investor in investor_list:
108
  companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
109
  for company in companies:
110
  G.add_edge(investor, company)
111
 
112
  pos = nx.spring_layout(G, seed=42)
 
 
 
 
 
 
 
 
 
113
  edge_trace = go.Scatter(
114
- x=[],
115
- y=[],
116
  line=dict(width=0.5, color='#888'),
117
  hoverinfo='none',
118
  mode='lines'
119
  )
120
 
121
- for edge in G.edges():
122
- x0, y0 = pos[edge[0]]
123
- x1, y1 = pos[edge[1]]
124
- edge_trace.x += [x0, x1, None]
125
- edge_trace.y += [y0, y1, None]
 
 
 
 
 
 
 
 
 
 
126
 
127
  node_trace = go.Scatter(
128
- x=[],
129
- y=[],
130
- text=[],
131
  mode='markers',
132
  hoverinfo='text',
133
  marker=dict(
134
  showscale=True,
135
  colorscale='YlGnBu',
136
  size=10,
137
- colorbar=dict(thickness=15, title="Node Value")
 
 
 
 
 
 
138
  )
139
  )
140
 
141
- for node in G.nodes():
142
- x, y = pos[node]
143
- node_trace.x += [x]
144
- node_trace.y += [y]
145
- node_trace.text += [node]
146
-
147
  fig = go.Figure(data=[edge_trace, node_trace])
 
 
 
 
 
 
 
148
  return fig
149
 
150
  # Gradio app
151
  def app(selected_country, selected_industry, valuation_threshold):
152
- investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry, valuation_threshold)
153
- graph = generate_graph(investor_list, filtered_data)
154
- return investor_list, graph
155
 
156
  def main():
157
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
158
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
159
-
160
  with gr.Blocks() as demo:
161
  with gr.Row():
162
- country_filter = gr.Dropdown(choices=country_list, label="Filter by Country", value="All")
163
- industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
164
- valuation_slider = gr.Slider(0, 50, value=20, label="Valuation Threshold (B)")
165
 
166
- investor_output = gr.Text(label="Investor List")
167
  graph_output = gr.Plot(label="Network Graph")
168
 
169
  country_filter.change(app, [country_filter, industry_filter, valuation_slider], [investor_output, graph_output])
 
27
 
28
  # Identify the valuation column
29
  valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
30
+ if len(valuation_columns) != 1:
31
+ logger.error("Unable to identify a single valuation column.")
32
+ raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
33
+
34
+ valuation_column = valuation_columns[0]
35
+ logger.info(f"Using valuation column: {valuation_column}")
 
 
 
36
 
37
  # Clean and prepare data
38
  data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
39
  data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
 
 
40
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
41
+ data.rename(columns={
 
 
 
42
  "company": "Company",
43
  "valuation_billions": "Valuation_Billions",
44
  "date_joined": "Date_Joined",
 
46
  "city": "City",
47
  "industry": "Industry",
48
  "select_investors": "Select_Investors"
49
+ }, inplace=True)
50
 
51
+ logger.info("Data cleaned and columns renamed.")
 
 
 
52
 
53
+ # Build investor-company mapping
 
 
 
54
  def build_investor_company_mapping(df):
55
  mapping = {}
56
  for _, row in df.iterrows():
 
66
  investor_company_mapping = build_investor_company_mapping(data)
67
  logger.info("Investor to company mapping created.")
68
 
69
+ # Filter investors by country, industry, and valuation threshold
70
+ def filter_investors(selected_country, selected_industry, valuation_threshold):
71
  filtered_data = data.copy()
72
  if selected_country != "All":
73
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
 
75
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
76
 
77
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
78
+ investor_valuations = {
79
+ investor: filtered_data[filtered_data["Company"].isin(companies)]["Valuation_Billions"].sum()
80
+ for investor, companies in investor_company_mapping_filtered.items()
81
+ }
82
+ filtered_investors = [investor for investor, total in investor_valuations.items() if total >= valuation_threshold]
83
+ return filtered_investors, filtered_data
 
84
 
85
  # Generate Plotly graph
86
def generate_graph(investors, filtered_data):
    """Build a Plotly network figure linking investors to companies.

    Args:
        investors: Collection of investor names to draw. When it is empty a
            warning is logged and an empty figure is returned.
        filtered_data: DataFrame providing the "Select_Investors", "Company"
            and "Valuation_Billions" columns used to derive edges and node
            colours.

    Returns:
        A ``go.Figure`` holding one line trace for the edges and one marker
        trace for the nodes. Company nodes are coloured by their summed
        "Valuation_Billions"; investor nodes get a fixed colour value.
    """
    if not investors:
        logger.warning("No investors selected.")
        return go.Figure()

    # Fixed colour value for investors (they have no valuation of their own).
    investor_color_value = 20
    # Set membership keeps the per-node investor check O(1).
    investor_names = set(investors)

    G = nx.Graph()
    for investor in investors:
        # regex=False: investor names are literal text, not patterns. Names
        # containing characters such as "(" or "+" would otherwise be
        # misread as regular expressions and mismatch or raise.
        matches = filtered_data["Select_Investors"].str.contains(
            investor, na=False, regex=False
        )
        for company in filtered_data[matches]["Company"].tolist():
            G.add_edge(investor, company)

    # Fixed seed keeps the layout reproducible between calls.
    pos = nx.spring_layout(G, seed=42)

    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None breaks the line so each edge is drawn separately.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Sum valuations per company once instead of scanning the frame per node.
    company_valuations = filtered_data.groupby("Company")["Valuation_Billions"].sum()

    node_x = []
    node_y = []
    node_text = []
    node_color = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
        if node in investor_names:
            node_color.append(investor_color_value)
        else:
            # Companies without a matching row fall back to 0, matching the
            # sum of an empty selection.
            node_color.append(company_valuations.get(node, 0.0))

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=10,
            color=node_color,
            colorbar=dict(
                thickness=15,
                # Nested title dict replaces the legacy `titleside` shortcut,
                # which newer Plotly releases no longer accept.
                title=dict(text="Valuation (B)", side='right'),
                xanchor='left'
            )
        )
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        showlegend=False,
        # Nested form of the legacy `titlefont_size` shortcut.
        title=dict(text="Venture Networks", font=dict(size=16)),
        margin=dict(l=40, r=40, t=40, b=40),
        hovermode='closest'
    )
    return fig
160
 
161
  # Gradio app
162
  def app(selected_country, selected_industry, valuation_threshold):
163
+ investors, filtered_data = filter_investors(selected_country, selected_industry, valuation_threshold)
164
+ graph = generate_graph(investors, filtered_data)
165
+ return investors, graph
166
 
167
  def main():
168
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
169
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
170
+
171
  with gr.Blocks() as demo:
172
  with gr.Row():
173
+ country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
174
+ industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
175
+ valuation_slider = gr.Slider(0, 50, value=20, step=1, label="Valuation Threshold (B)")
176
 
177
+ investor_output = gr.Textbox(label="Filtered Investors")
178
  graph_output = gr.Plot(label="Network Graph")
179
 
180
  country_filter.change(app, [country_filter, industry_filter, valuation_slider], [investor_output, graph_output])