LeonceNsh commited on
Commit
970f3bc
·
verified ·
1 Parent(s): 1e9f79e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -9
app.py CHANGED
@@ -106,17 +106,63 @@ def filter_investors_by_country_and_industry(selected_country, selected_industry
106
 
107
  return list(investor_valuations.keys()), filtered_data
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  # Gradio app function
110
  def app(selected_country, selected_industry, valuation_threshold):
111
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry, valuation_threshold)
112
-
113
- if not investor_list:
114
- return (
115
- "No investors meet the selected criteria. Try reducing the valuation threshold or selecting different filters.",
116
- None
117
- )
118
-
119
- return investor_list, generate_graph(investor_list, filtered_data)
120
 
121
  # Gradio Interface
122
  def main():
@@ -132,8 +178,8 @@ def main():
132
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
133
  valuation_threshold = gr.Slider(minimum=0, maximum=50, step=1, value=20, label="Valuation Threshold (in B)")
134
 
135
- graph_output = gr.Plot(label="Venture Network Graph")
136
  investor_output = gr.Text(label="Investor Results")
 
137
 
138
  country_filter.change(
139
  app,
 
106
 
107
  return list(investor_valuations.keys()), filtered_data
108
 
109
+ # Function to generate the Plotly graph
110
+ def generate_graph(investor_list, filtered_data):
111
+ if not investor_list:
112
+ logger.warning("No investors selected. Returning empty figure.")
113
+ return go.Figure()
114
+
115
+ G = nx.Graph()
116
+ for investor in investor_list:
117
+ companies = filtered_data[filtered_data["Select_Investors"].str.contains(investor, na=False)]["Company"].tolist()
118
+ for company in companies:
119
+ G.add_edge(investor, company)
120
+
121
+ pos = nx.spring_layout(G, k=0.2, seed=42)
122
+
123
+ # Create Plotly traces for edges and nodes
124
+ edge_trace = go.Scatter(
125
+ x=[],
126
+ y=[],
127
+ line=dict(width=0.5, color='#888'),
128
+ hoverinfo='none',
129
+ mode='lines'
130
+ )
131
+
132
+ for edge in G.edges():
133
+ x0, y0 = pos[edge[0]]
134
+ x1, y1 = pos[edge[1]]
135
+ edge_trace['x'] += [x0, x1, None]
136
+ edge_trace['y'] += [y0, y1, None]
137
+
138
+ node_trace = go.Scatter(
139
+ x=[],
140
+ y=[],
141
+ text=[],
142
+ mode='markers',
143
+ hoverinfo='text',
144
+ marker=dict(
145
+ showscale=True,
146
+ colorscale='YlGnBu',
147
+ size=10,
148
+ colorbar=dict(thickness=15, title='Node Valuation')
149
+ )
150
+ )
151
+
152
+ for node in G.nodes():
153
+ x, y = pos[node]
154
+ node_trace['x'] += [x]
155
+ node_trace['y'] += [y]
156
+ node_trace['text'] += [f"{node}"]
157
+
158
+ fig = go.Figure(data=[edge_trace, node_trace])
159
+ return fig
160
+
161
  # Gradio app function
162
  def app(selected_country, selected_industry, valuation_threshold):
163
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry, valuation_threshold)
164
+ graph = generate_graph(investor_list, filtered_data)
165
+ return investor_list, graph
 
 
 
 
 
 
166
 
167
  # Gradio Interface
168
  def main():
 
178
  industry_filter = gr.Dropdown(choices=industry_list, label="Filter by Industry", value="All")
179
  valuation_threshold = gr.Slider(minimum=0, maximum=50, step=1, value=20, label="Valuation Threshold (in B)")
180
 
 
181
  investor_output = gr.Text(label="Investor Results")
182
+ graph_output = gr.Plot(label="Venture Network Graph")
183
 
184
  country_filter.change(
185
  app,