LeonceNsh commited on
Commit
eedc3a8
·
verified ·
1 Parent(s): 0de2f41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -15
app.py CHANGED
@@ -29,6 +29,7 @@ logger.info(f"Standardized Column Names: {data.columns.tolist()}")
29
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
30
  data.rename(columns={
31
  "company": "Company",
 
32
  "date_joined": "Date_Joined",
33
  "country": "Country",
34
  "city": "City",
@@ -36,6 +37,10 @@ data.rename(columns={
36
  "select_investors": "Select_Investors"
37
  }, inplace=True)
38
 
 
 
 
 
39
  logger.info("Data cleaned and columns renamed.")
40
 
41
  # Build investor-company mapping
@@ -55,15 +60,17 @@ investor_company_mapping = build_investor_company_mapping(data)
55
  logger.info("Investor to company mapping created.")
56
 
57
  # Filter investors by country, industry, and investor selection
58
- def filter_investors(selected_country, selected_industry, selected_investor):
59
  filtered_data = data.copy()
60
  if selected_country != "All":
61
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
62
  if selected_industry != "All":
63
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
64
- if selected_investor != "All":
65
  filtered_data = filtered_data[
66
- filtered_data["Select_Investors"].str.contains(selected_investor, na=False)
 
 
67
  ]
68
 
69
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
@@ -95,36 +102,48 @@ def generate_graph(investors, filtered_data):
95
  edge_trace = go.Scatter(
96
  x=edge_x,
97
  y=edge_y,
98
- line=dict(width=1, color='#888'),
99
- hoverinfo='none',
100
- mode='lines'
101
  )
102
 
103
  node_x = []
104
  node_y = []
105
  node_text = []
106
  node_color = []
 
 
 
 
 
 
 
 
107
 
108
  for node in G.nodes():
109
  x, y = pos[node]
110
  node_x.append(x)
111
  node_y.append(y)
 
112
  if node in investors:
113
  node_text.append(node) # Label investors
114
- node_color.append(20) # Fixed color for investors
 
115
  else:
 
116
  node_text.append("") # Hide company labels by default
117
- node_color.append(10)
 
118
 
119
  node_trace = go.Scatter(
120
  x=node_x,
121
  y=node_y,
122
  text=node_text,
123
- mode='markers',
124
- hoverinfo='text',
125
  marker=dict(
126
  showscale=False,
127
- size=15,
128
  color=node_color,
129
  )
130
  )
@@ -135,15 +154,15 @@ def generate_graph(investors, filtered_data):
135
  title="Venture Networks",
136
  titlefont_size=20,
137
  margin=dict(l=20, r=20, t=50, b=20),
138
- hovermode='closest',
139
  width=1200,
140
  height=800
141
  )
142
  return fig
143
 
144
  # Gradio app
145
- def app(selected_country, selected_industry, selected_investor):
146
- investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investor)
147
  graph = generate_graph(investors, filtered_data)
148
  return investors, graph
149
 
@@ -157,7 +176,7 @@ def main():
157
  with gr.Row():
158
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
159
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
160
- investor_filter = gr.Dropdown(choices=investor_list, label="Investor", value="All")
161
 
162
  investor_output = gr.Textbox(label="Filtered Investors")
163
  graph_output = gr.Plot(label="Network Graph")
 
29
  data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
30
  data.rename(columns={
31
  "company": "Company",
32
+ "valuation": "Valuation",
33
  "date_joined": "Date_Joined",
34
  "country": "Country",
35
  "city": "City",
 
37
  "select_investors": "Select_Investors"
38
  }, inplace=True)
39
 
40
+ # Convert valuation to numeric for proportional node sizing
41
+ data["Valuation"] = pd.to_numeric(
42
+ data["Valuation"].replace({"\$": "", ",": ""}, regex=True), errors="coerce"
43
+ )
44
  logger.info("Data cleaned and columns renamed.")
45
 
46
  # Build investor-company mapping
 
60
  logger.info("Investor to company mapping created.")
61
 
62
  # Filter investors by country, industry, and investor selection
63
+ def filter_investors(selected_country, selected_industry, selected_investors):
64
  filtered_data = data.copy()
65
  if selected_country != "All":
66
  filtered_data = filtered_data[filtered_data["Country"] == selected_country]
67
  if selected_industry != "All":
68
  filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
69
+ if selected_investors != ["All"]:
70
  filtered_data = filtered_data[
71
+ filtered_data["Select_Investors"].apply(
72
+ lambda x: any(inv in x for inv in selected_investors) if pd.notnull(x) else False
73
+ )
74
  ]
75
 
76
  investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
 
102
  edge_trace = go.Scatter(
103
  x=edge_x,
104
  y=edge_y,
105
+ line=dict(width=1, color="#888"),
106
+ hoverinfo="none",
107
+ mode="lines"
108
  )
109
 
110
  node_x = []
111
  node_y = []
112
  node_text = []
113
  node_color = []
114
+ node_size = []
115
+
116
+ # Color palette for investors (color blind friendly)
117
+ investor_colors = [
118
+ "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"
119
+ ]
120
+
121
+ investor_color_map = {investor: investor_colors[i % len(investor_colors)] for i, investor in enumerate(investors)}
122
 
123
  for node in G.nodes():
124
  x, y = pos[node]
125
  node_x.append(x)
126
  node_y.append(y)
127
+
128
  if node in investors:
129
  node_text.append(node) # Label investors
130
+ node_color.append(investor_color_map[node]) # Assign distinct colors to investors
131
+ node_size.append(20) # Fixed size for investor nodes
132
  else:
133
+ valuation = filtered_data.loc[filtered_data["Company"] == node, "Valuation"].sum()
134
  node_text.append("") # Hide company labels by default
135
+ node_color.append("lightgreen") # Light green for companies
136
+ node_size.append(max(10, valuation / 100)) # Size proportional to valuation
137
 
138
  node_trace = go.Scatter(
139
  x=node_x,
140
  y=node_y,
141
  text=node_text,
142
+ mode="markers",
143
+ hoverinfo="text",
144
  marker=dict(
145
  showscale=False,
146
+ size=node_size,
147
  color=node_color,
148
  )
149
  )
 
154
  title="Venture Networks",
155
  titlefont_size=20,
156
  margin=dict(l=20, r=20, t=50, b=20),
157
+ hovermode="closest",
158
  width=1200,
159
  height=800
160
  )
161
  return fig
162
 
163
  # Gradio app
164
+ def app(selected_country, selected_industry, selected_investors):
165
+ investors, filtered_data = filter_investors(selected_country, selected_industry, selected_investors)
166
  graph = generate_graph(investors, filtered_data)
167
  return investors, graph
168
 
 
176
  with gr.Row():
177
  country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
178
  industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
179
+ investor_filter = gr.CheckboxGroup(choices=investor_list, label="Investors", value=["All"])
180
 
181
  investor_output = gr.Textbox(label="Filtered Investors")
182
  graph_output = gr.Plot(label="Network Graph")