LeonceNsh commited on
Commit
f7d5c33
·
verified ·
1 Parent(s): 22d1cb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -46
app.py CHANGED
@@ -25,22 +25,10 @@ except Exception as e:
25
  data.columns = data.columns.str.strip().str.lower()
26
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
 
28
- # Identify the valuation column
29
- valuation_columns = [col for col in data.columns if 'valuation' in col.lower()]
30
- if len(valuation_columns) != 1:
31
- logger.error("Unable to identify a single valuation column.")
32
- raise ValueError("Dataset should contain exactly one column with 'valuation' in its name.")
33
-
34
- valuation_column = valuation_columns[0]
35
- logger.info(f"Using valuation column: {valuation_column}")
36
-
37
  # Clean and prepare data
38
- data["valuation_billions"] = data[valuation_column].replace({'\$': '', ',': ''}, regex=True)
39
- data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
40
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
41
  data.rename(columns={
42
  "company": "Company",
43
- "valuation_billions": "Valuation_Billions",
44
  "date_joined": "Date_Joined",
45
  "country": "Country",
46
  "city": "City",
@@ -66,8 +54,8 @@ def build_investor_company_mapping(df):
66
  investor_company_mapping = build_investor_company_mapping(data)
67
  logger.info("Investor to company mapping created.")
68
 
69
- # Filter investors by country, industry, and valuation threshold
70
- def filter_investors(selected_country, selected_industry, valuation_threshold):
71
  filtered_data = data.copy()
72
  if selected_country != "All":
73
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
@@ -75,14 +63,10 @@ def filter_investors(selected_country, selected_industry, valuation_threshold):
75
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
76
 
77
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
78
- investor_valuations = {
79
- investor: filtered_data[filtered_data["Company"].isin(companies)]["Valuation_Billions"].sum()
80
- for investor, companies in investor_company_mapping_filtered.items()
81
- }
82
- filtered_investors = [investor for investor, total in investor_valuations.items() if total >= valuation_threshold]
83
  return filtered_investors, filtered_data
84
 
85
- # Generate Plotly graph
86
  def generate_graph(investors, filtered_data):
87
  if not investors:
88
  logger.warning("No investors selected.")
@@ -107,7 +91,7 @@ def generate_graph(investors, filtered_data):
107
  edge_trace = go.Scatter(
108
  x=edge_x,
109
  y=edge_y,
110
- line=dict(width=0.5, color='#888'),
111
  hoverinfo='none',
112
  mode='lines'
113
  )
@@ -122,48 +106,41 @@ def generate_graph(investors, filtered_data):
122
  node_x.append(x)
123
  node_y.append(y)
124
  node_text.append(node)
125
- if node in investors:
126
- node_color.append(20) # Fixed color value for investors
127
- else:
128
- valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation_Billions"].sum()
129
- node_color.append(valuation)
130
 
131
  node_trace = go.Scatter(
132
  x=node_x,
133
  y=node_y,
134
  text=node_text,
135
- mode='markers',
136
  hoverinfo='text',
137
  marker=dict(
138
- showscale=True,
139
- colorscale='YlGnBu',
140
- size=13,
141
  color=node_color,
142
- colorbar=dict(
143
- thickness=15,
144
- title="Valuation (B)",
145
- xanchor='left',
146
- titleside='right'
147
- )
148
- )
149
  )
150
 
151
  fig = go.Figure(data=[edge_trace, node_trace])
152
  fig.update_layout(
153
  showlegend=False,
154
  title="Venture Networks",
155
- titlefont_size=16,
156
- margin=dict(l=40, r=40, t=40, b=40),
157
- hovermode='closest'
 
 
158
  )
159
  return fig
160
 
161
  # Gradio app
162
- def app(selected_country, selected_industry, valuation_threshold):
163
- investors, filtered_data = filter_investors(selected_country, selected_industry, valuation_threshold)
164
  graph = generate_graph(investors, filtered_data)
165
  return investors, graph
166
 
 
167
  def main():
168
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
169
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
@@ -172,14 +149,12 @@ def main():
172
  with gr.Row():
173
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
174
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
175
- valuation_slider = gr.Slider(0, 50, value=20, step=1, label="Valuation Threshold (B)")
176
 
177
  investor_output = gr.Textbox(label="Filtered Investors")
178
  graph_output = gr.Plot(label="Network Graph")
179
 
180
- country_filter.change(app, [country_filter, industry_filter, valuation_slider], [investor_output, graph_output])
181
- industry_filter.change(app, [country_filter, industry_filter, valuation_slider], [investor_output, graph_output])
182
- valuation_slider.change(app, [country_filter, industry_filter, valuation_slider], [investor_output, graph_output])
183
 
184
  demo.launch()
185
 
 
25
  data.columns = data.columns.str.strip().str.lower()
26
  logger.info(f"Standardized Column Names: {data.columns.tolist()}")
27
 
 
 
 
 
 
 
 
 
 
28
  # Clean and prepare data
 
 
29
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
30
  data.rename(columns={
31
  "company": "Company",
 
32
  "date_joined": "Date_Joined",
33
  "country": "Country",
34
  "city": "City",
 
54
  investor_company_mapping = build_investor_company_mapping(data)
55
  logger.info("Investor to company mapping created.")
56
 
57
+ # Filter investors by country and industry (removed valuation threshold)
58
+ def filter_investors(selected_country, selected_industry):
59
  filtered_data = data.copy()
60
  if selected_country != "All":
61
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
 
63
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
64
 
65
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
66
+ filtered_investors = list(investor_company_mapping_filtered.keys())
 
 
 
 
67
  return filtered_investors, filtered_data
68
 
69
+ # Generate Plotly graph with increased size and improved styling
70
  def generate_graph(investors, filtered_data):
71
  if not investors:
72
  logger.warning("No investors selected.")
 
91
  edge_trace = go.Scatter(
92
  x=edge_x,
93
  y=edge_y,
94
+ line=dict(width=1, color='#888'),
95
  hoverinfo='none',
96
  mode='lines'
97
  )
 
106
  node_x.append(x)
107
  node_y.append(y)
108
  node_text.append(node)
109
+ node_color.append(10) # Use a fixed color or other logic
 
 
 
 
110
 
111
  node_trace = go.Scatter(
112
  x=node_x,
113
  y=node_y,
114
  text=node_text,
115
+ mode='markers+text',
116
  hoverinfo='text',
117
  marker=dict(
118
+ showscale=False,
119
+ size=15, # Increased size
 
120
  color=node_color,
121
+ ),
122
+ textposition="top center" # Improved label positioning
 
 
 
 
 
123
  )
124
 
125
  fig = go.Figure(data=[edge_trace, node_trace])
126
  fig.update_layout(
127
  showlegend=False,
128
  title="Venture Networks",
129
+ titlefont_size=20,
130
+ margin=dict(l=20, r=20, t=50, b=20),
131
+ hovermode='closest',
132
+ width=1200, # Increased width
133
+ height=800 # Increased height
134
  )
135
  return fig
136
 
137
  # Gradio app
138
+ def app(selected_country, selected_industry):
139
+ investors, filtered_data = filter_investors(selected_country, selected_industry)
140
  graph = generate_graph(investors, filtered_data)
141
  return investors, graph
142
 
143
+ # Main function
144
  def main():
145
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
146
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
 
149
  with gr.Row():
150
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
151
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
 
152
 
153
  investor_output = gr.Textbox(label="Filtered Investors")
154
  graph_output = gr.Plot(label="Network Graph")
155
 
156
+ country_filter.change(app, [country_filter, industry_filter], [investor_output, graph_output])
157
+ industry_filter.change(app, [country_filter, industry_filter], [investor_output, graph_output])
 
158
 
159
  demo.launch()
160