Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -54,19 +54,23 @@ def build_investor_company_mapping(df):
|
|
54 |
investor_company_mapping = build_investor_company_mapping(data)
|
55 |
logger.info("Investor to company mapping created.")
|
56 |
|
57 |
-
# Filter investors by country
|
58 |
-
def filter_investors(selected_country, selected_industry):
|
59 |
filtered_data = data.copy()
|
60 |
if selected_country != "All":
|
61 |
filtered_data = filtered_data[filtered_data["Country"] == selected_country]
|
62 |
if selected_industry != "All":
|
63 |
filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
|
|
|
|
|
|
|
|
|
64 |
|
65 |
investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
|
66 |
filtered_investors = list(investor_company_mapping_filtered.keys())
|
67 |
return filtered_investors, filtered_data
|
68 |
|
69 |
-
# Generate Plotly graph
|
70 |
def generate_graph(investors, filtered_data):
|
71 |
if not investors:
|
72 |
logger.warning("No investors selected.")
|
@@ -105,21 +109,24 @@ def generate_graph(investors, filtered_data):
|
|
105 |
x, y = pos[node]
|
106 |
node_x.append(x)
|
107 |
node_y.append(y)
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
110 |
|
111 |
node_trace = go.Scatter(
|
112 |
x=node_x,
|
113 |
y=node_y,
|
114 |
text=node_text,
|
115 |
-
mode='markers
|
116 |
hoverinfo='text',
|
117 |
marker=dict(
|
118 |
showscale=False,
|
119 |
-
size=15,
|
120 |
color=node_color,
|
121 |
-
)
|
122 |
-
textposition="top center" # Improved label positioning
|
123 |
)
|
124 |
|
125 |
fig = go.Figure(data=[edge_trace, node_trace])
|
@@ -129,14 +136,14 @@ def generate_graph(investors, filtered_data):
|
|
129 |
titlefont_size=20,
|
130 |
margin=dict(l=20, r=20, t=50, b=20),
|
131 |
hovermode='closest',
|
132 |
-
width=1200,
|
133 |
-
height=800
|
134 |
)
|
135 |
return fig
|
136 |
|
137 |
-
#
|
138 |
-
def app(selected_country, selected_industry):
|
139 |
-
investors, filtered_data = filter_investors(selected_country, selected_industry)
|
140 |
graph = generate_graph(investors, filtered_data)
|
141 |
return investors, graph
|
142 |
|
@@ -144,17 +151,26 @@ def app(selected_country, selected_industry):
|
|
144 |
def main():
|
145 |
country_list = ["All"] + sorted(data["Country"].dropna().unique())
|
146 |
industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
|
|
|
147 |
|
148 |
with gr.Blocks() as demo:
|
149 |
with gr.Row():
|
150 |
country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
|
151 |
industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
|
|
|
152 |
|
153 |
investor_output = gr.Textbox(label="Filtered Investors")
|
154 |
graph_output = gr.Plot(label="Network Graph")
|
155 |
|
156 |
-
country_filter.change(
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
demo.launch()
|
160 |
|
|
|
54 |
investor_company_mapping = build_investor_company_mapping(data)
|
55 |
logger.info("Investor to company mapping created.")
|
56 |
|
57 |
+
# Filter investors by country, industry, and investor selection
|
58 |
+
def filter_investors(selected_country, selected_industry, selected_investor):
|
59 |
filtered_data = data.copy()
|
60 |
if selected_country != "All":
|
61 |
filtered_data = filtered_data[filtered_data["Country"] == selected_country]
|
62 |
if selected_industry != "All":
|
63 |
filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
|
64 |
+
if selected_investor != "All":
|
65 |
+
filtered_data = filtered_data[
|
66 |
+
filtered_data["Select_Investors"].str.contains(selected_investor, na=False)
|
67 |
+
]
|
68 |
|
69 |
investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
|
70 |
filtered_investors = list(investor_company_mapping_filtered.keys())
|
71 |
return filtered_investors, filtered_data
|
72 |
|
73 |
+
# Generate Plotly graph
|
74 |
def generate_graph(investors, filtered_data):
|
75 |
if not investors:
|
76 |
logger.warning("No investors selected.")
|
|
|
109 |
x, y = pos[node]
|
110 |
node_x.append(x)
|
111 |
node_y.append(y)
|
112 |
+
if node in investors:
|
113 |
+
node_text.append(node) # Label investors
|
114 |
+
node_color.append(20) # Fixed color for investors
|
115 |
+
else:
|
116 |
+
node_text.append("") # Hide company labels by default
|
117 |
+
node_color.append(10)
|
118 |
|
119 |
node_trace = go.Scatter(
|
120 |
x=node_x,
|
121 |
y=node_y,
|
122 |
text=node_text,
|
123 |
+
mode='markers',
|
124 |
hoverinfo='text',
|
125 |
marker=dict(
|
126 |
showscale=False,
|
127 |
+
size=15,
|
128 |
color=node_color,
|
129 |
+
)
|
|
|
130 |
)
|
131 |
|
132 |
fig = go.Figure(data=[edge_trace, node_trace])
|
|
|
136 |
titlefont_size=20,
|
137 |
margin=dict(l=20, r=20, t=50, b=20),
|
138 |
hovermode='closest',
|
139 |
+
width=1200,
|
140 |
+
height=800
|
141 |
)
|
142 |
return fig
|
143 |
|
144 |
+
# Gradio app
|
145 |
+
def app(selected_country, selected_industry, selected_investor):
|
146 |
+
investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investor)
|
147 |
graph = generate_graph(investors, filtered_data)
|
148 |
return investors, graph
|
149 |
|
|
|
151 |
def main():
|
152 |
country_list = ["All"] + sorted(data["Country"].dropna().unique())
|
153 |
industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
|
154 |
+
investor_list = ["All"] + sorted(investor_company_mapping.keys())
|
155 |
|
156 |
with gr.Blocks() as demo:
|
157 |
with gr.Row():
|
158 |
country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
|
159 |
industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
|
160 |
+
investor_filter = gr.Dropdown(choices=investor_list, label="Investor", value="All")
|
161 |
|
162 |
investor_output = gr.Textbox(label="Filtered Investors")
|
163 |
graph_output = gr.Plot(label="Network Graph")
|
164 |
|
165 |
+
country_filter.change(
|
166 |
+
app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
|
167 |
+
)
|
168 |
+
industry_filter.change(
|
169 |
+
app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
|
170 |
+
)
|
171 |
+
investor_filter.change(
|
172 |
+
app, [country_filter, industry_filter, investor_filter], [investor_output, graph_output]
|
173 |
+
)
|
174 |
|
175 |
demo.launch()
|
176 |
|