LeonceNsh commited on
Commit
0de2f41
·
verified ·
1 Parent(s): cc1818e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -16
app.py CHANGED
@@ -54,19 +54,23 @@ def build_investor_company_mapping(df):
54
  investor_company_mapping = build_investor_company_mapping(data)
55
  logger.info("Investor to company mapping created.")
56
 
57
- # Filter investors by country and industry (removed valuation threshold)
58
- def filter_investors(selected_country, selected_industry):
59
  filtered_data = data.copy()
60
  if selected_country != "All":
61
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
62
  if selected_industry != "All":
63
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
 
 
 
 
64
 
65
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
66
  filtered_investors = list(investor_company_mapping_filtered.keys())
67
  return filtered_investors, filtered_data
68
 
69
- # Generate Plotly graph with increased size and improved styling
70
  def generate_graph(investors, filtered_data):
71
  if not investors:
72
  logger.warning("No investors selected.")
@@ -105,21 +109,24 @@ def generate_graph(investors, filtered_data):
105
  x, y = pos[node]
106
  node_x.append(x)
107
  node_y.append(y)
108
- node_text.append(node)
109
- node_color.append(10) # Use a fixed color or other logic
 
 
 
 
110
 
111
  node_trace = go.Scatter(
112
  x=node_x,
113
  y=node_y,
114
  text=node_text,
115
- mode='markers+text',
116
  hoverinfo='text',
117
  marker=dict(
118
  showscale=False,
119
- size=15, # Increased size
120
  color=node_color,
121
- ),
122
- textposition="top center" # Improved label positioning
123
  )
124
 
125
  fig = go.Figure(data=[edge_trace, node_trace])
@@ -129,14 +136,14 @@ def generate_graph(investors, filtered_data):
129
  titlefont_size=20,
130
  margin=dict(l=20, r=20, t=50, b=20),
131
  hovermode='closest',
132
- width=1200, # Increased width
133
- height=800 # Increased height
134
  )
135
  return fig
136
 
137
- # Update the Gradio app to remove valuation threshold
138
- def app(selected_country, selected_industry):
139
- investors, filtered_data = filter_investors(selected_country, selected_industry)
140
  graph = generate_graph(investors, filtered_data)
141
  return investors, graph
142
 
@@ -144,17 +151,26 @@ def app(selected_country, selected_industry):
144
  def main():
145
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
146
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
 
147
 
148
  with gr.Blocks() as demo:
149
  with gr.Row():
150
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
151
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
 
152
 
153
  investor_output = gr.Textbox(label="Filtered Investors")
154
  graph_output = gr.Plot(label="Network Graph")
155
 
156
- country_filter.change(app, [country_filter, industry_filter], [investor_output, graph_output])
157
- industry_filter.change(app, [country_filter, industry_filter], [investor_output, graph_output])
 
 
 
 
 
 
 
158
 
159
  demo.launch()
160
 
 
54
  investor_company_mapping = build_investor_company_mapping(data)
55
  logger.info("Investor to company mapping created.")
56
 
57
+ # Filter investors by country, industry, and investor selection
58
+ def filter_investors(selected_country, selected_industry, selected_investor):
59
  filtered_data = data.copy()
60
  if selected_country != "All":
61
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
62
  if selected_industry != "All":
63
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
64
+ if selected_investor != "All":
65
+ filtered_data = filtered_data[
66
+ filtered_data["Select_Investors"].str.contains(selected_investor, na=False)
67
+ ]
68
 
69
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
70
  filtered_investors = list(investor_company_mapping_filtered.keys())
71
  return filtered_investors, filtered_data
72
 
73
+ # Generate Plotly graph
74
  def generate_graph(investors, filtered_data):
75
  if not investors:
76
  logger.warning("No investors selected.")
 
109
  x, y = pos[node]
110
  node_x.append(x)
111
  node_y.append(y)
112
+ if node in investors:
113
+ node_text.append(node) # Label investors
114
+ node_color.append(20) # Fixed color for investors
115
+ else:
116
+ node_text.append("") # Hide company labels by default
117
+ node_color.append(10)
118
 
119
  node_trace = go.Scatter(
120
  x=node_x,
121
  y=node_y,
122
  text=node_text,
123
+ mode='markers',
124
  hoverinfo='text',
125
  marker=dict(
126
  showscale=False,
127
+ size=15,
128
  color=node_color,
129
+ )
 
130
  )
131
 
132
  fig = go.Figure(data=[edge_trace, node_trace])
 
136
  titlefont_size=20,
137
  margin=dict(l=20, r=20, t=50, b=20),
138
  hovermode='closest',
139
+ width=1200,
140
+ height=800
141
  )
142
  return fig
143
 
144
+ # Gradio app
145
+ def app(selected_country, selected_industry, selected_investor):
146
+ investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investor)
147
  graph = generate_graph(investors, filtered_data)
148
  return investors, graph
149
 
 
151
  def main():
152
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
153
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
154
+ investor_list = ["All"] + sorted(investor_company_mapping.keys())
155
 
156
  with gr.Blocks() as demo:
157
  with gr.Row():
158
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
159
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
160
+ investor_filter = gr.Dropdown(choices=investor_list, label="Investor", value="All")
161
 
162
  investor_output = gr.Textbox(label="Filtered Investors")
163
  graph_output = gr.Plot(label="Network Graph")
164
 
165
+ country_filter.change(
166
+ app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
167
+ )
168
+ industry_filter.change(
169
+ app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
170
+ )
171
+ investor_filter.change(
172
+ app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
173
+ )
174
 
175
  demo.launch()
176