LeonceNsh commited on
Commit
5240604
·
verified ·
1 Parent(s): 04d2663

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -160
app.py CHANGED
@@ -40,165 +40,69 @@ logger.info(f"Using valuation column: {valuation_column}")
40
 
41
  # Clean and prepare data
42
  data["valuation_billions"] = data[valuation_column].apply(
43
- lambda x: float(re.sub(r"[^\d.]", "", str(x))) if pd.notnull(x) else 0
44
  )
45
- data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
46
-
47
- # Clean string columns
48
- data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
49
- data.rename(columns={
50
- "company": "Company",
51
- "date_joined": "Date_Joined",
52
- "country": "Country",
53
- "city": "City",
54
- "industry": "Industry",
55
- "select_investors": "Select_Investors"
56
- }, inplace=True)
57
-
58
- logger.info("Data cleaned and columns renamed.")
59
-
60
- # Build investor-company mapping
61
- def build_investor_company_mapping(df):
62
- mapping = {}
63
- for _, row in df.iterrows():
64
- company = row["Company"]
65
- investors = row["Select_Investors"]
66
- if pd.notnull(investors):
67
- for investor in investors.split(","):
68
- investor = investor.strip()
69
- if investor:
70
- mapping.setdefault(investor, []).append(company)
71
- return mapping
72
-
73
- investor_company_mapping = build_investor_company_mapping(data)
74
- logger.info("Investor to company mapping created.")
75
-
76
- # -------------------------
77
- # Valuation-Range Logic
78
- # -------------------------
79
- def filter_by_valuation_range(df, selected_valuation_range):
80
- """Filter dataframe by the specified valuation range in billions."""
81
- if selected_valuation_range == "All":
82
- return df # No further filtering
83
-
84
- if selected_valuation_range == "1-5":
85
- return df[(df["valuation_billions"] >= 1) & (df["valuation_billions"] < 5)]
86
- elif selected_valuation_range == "5-10":
87
- return df[(df["valuation_billions"] >= 5) & (df["valuation_billions"] < 10)]
88
- elif selected_valuation_range == "10-15":
89
- return df[(df["valuation_billions"] >= 10) & (df["valuation_billions"] < 15)]
90
- elif selected_valuation_range == "15-20":
91
- return df[(df["valuation_billions"] >= 15) & (df["valuation_billions"] < 20)]
92
- elif selected_valuation_range == "20+":
93
- return df[df["valuation_billions"] >= 20]
94
- else:
95
- return df # Fallback, should never happen
96
-
97
- # Filter investors by country, industry, investor selection, company selection, and valuation range
98
- def filter_investors(
99
- selected_country,
100
- selected_industry,
101
- selected_investors,
102
- selected_company,
103
- exclude_countries,
104
- exclude_industries,
105
- exclude_companies,
106
- selected_valuation_range
107
- ):
108
- filtered_data = data.copy()
109
-
110
- # Apply filters
111
- if selected_country:
112
- filtered_data = filtered_data[filtered_data["Country"] == selected_country]
113
- if selected_industry:
114
- filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
115
- if selected_investors:
116
- filtered_data = filtered_data[filtered_data["Select_Investors"].str.contains(selected_investors, na=False)]
117
- if selected_company:
118
- filtered_data = filtered_data[filtered_data["Company"] == selected_company]
119
- if exclude_countries:
120
- filtered_data = filtered_data[~filtered_data["Country"].isin(exclude_countries)]
121
- if exclude_industries:
122
- filtered_data = filtered_data[~filtered_data["Industry"].isin(exclude_industries)]
123
- if exclude_companies:
124
- filtered_data = filtered_data[~filtered_data["Company"].isin(exclude_companies)]
125
- if selected_valuation_range:
126
- filtered_data = filter_by_valuation_range(filtered_data, selected_valuation_range)
127
-
128
- return filtered_data
129
-
130
- # Create the graph visualization
131
- def create_network_graph(filtered_data):
132
- graph = nx.Graph()
133
-
134
- # Add nodes and edges
135
- for _, row in filtered_data.iterrows():
136
- company = row['Company']
137
- venture_firm = row['Select_Investors']
138
-
139
- # Add company node (green color, size based on valuation)
140
- graph.add_node(company, node_color='green', node_size=row['valuation_billions'] * 10)
141
-
142
- # Add venture firm node (different color, fixed size)
143
- graph.add_node(venture_firm, node_color='blue', node_size=30)
144
-
145
- # Add an edge between the company and the venture firm
146
- graph.add_edge(company, venture_firm)
147
-
148
- # Generate visualization
149
- pos = nx.spring_layout(graph) # Layout for positioning
150
- fig = go.Figure()
151
-
152
- # Add nodes
153
- for node, attrs in graph.nodes(data=True):
154
- fig.add_trace(go.Scatter(
155
- x=[pos[node][0]], y=[pos[node][1]],
156
- mode='markers+text',
157
- marker=dict(size=attrs['node_size'], color=attrs['node_color']),
158
- text=node,
159
- textposition='top center'
160
- ))
161
-
162
- # Add edges
163
- for edge in graph.edges:
164
- x0, y0 = pos[edge[0]]
165
- x1, y1 = pos[edge[1]]
166
- fig.add_trace(go.Scatter(
167
- x=[x0, x1, None], y=[y0, y1, None],
168
- mode='lines',
169
- line=dict(width=1, color='grey')
170
- ))
171
-
172
- fig.update_layout(title="Company-Investor Network", showlegend=False)
173
-
174
- # Add the note to the plot
175
- note = ("Note: All companies are in green while venture firms have different colors. "
176
- "The diameter of the company circle varies proportionate to the valuation of the corresponding company.")
177
- logger.info(note)
178
-
179
- return fig
180
-
181
- # Gradio interface
182
- def display_network(selected_country, selected_industry, selected_investors, selected_company,
183
- exclude_countries, exclude_industries, exclude_companies, selected_valuation_range):
184
- filtered_data = filter_investors(
185
- selected_country, selected_industry, selected_investors, selected_company,
186
- exclude_countries, exclude_industries, exclude_companies, selected_valuation_range
187
  )
188
- return create_network_graph(filtered_data)
189
-
190
- # Gradio interface inputs
191
- inputs = [
192
- gr.Textbox(label="Select Country"),
193
- gr.Textbox(label="Select Industry"),
194
- gr.Textbox(label="Select Investors"),
195
- gr.Textbox(label="Select Company"),
196
- gr.Textbox(label="Exclude Countries"),
197
- gr.Textbox(label="Exclude Industries"),
198
- gr.Textbox(label="Exclude Companies"),
199
- gr.Radio(choices=["All", "1-5", "5-10", "10-15", "15-20", "20+"], label="Valuation Range")
200
- ]
201
-
202
- # Launch the Gradio interface
203
- interface = gr.Interface(fn=display_network, inputs=inputs, outputs="plot", live=True)
204
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # Clean and prepare data
42
  data["valuation_billions"] = data[valuation_column].apply(
43
+ lambda x: float(re.sub(r"[^0-9.]", "", str(x))) / 1e9 if pd.notnull(x) else 0
44
  )
45
+ logger.info("Valuation column cleaned and converted to billions.")
46
+
47
+ # Create a graph
48
+ G = nx.Graph()
49
+
50
+ for _, row in data.iterrows():
51
+ company_name = row["company"]
52
+ valuation = row["valuation_billions"]
53
+ industry = row["industry"]
54
+
55
+ # Add company node
56
+ G.add_node(
57
+ company_name,
58
+ size=valuation,
59
+ color="green" if industry == "Venture Firm" else "blue",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  )
61
+
62
+ # Add connections based on relationships (assume relationships column exists)
63
+ if "relationships" in data.columns:
64
+ relationships = str(row["relationships"]).split(";")
65
+ for relation in relationships:
66
+ G.add_edge(company_name, relation.strip())
67
+
68
+ # Create Plotly visualization
69
+ node_sizes = [G.nodes[node]["size"] * 50 for node in G.nodes]
70
+ node_colors = [G.nodes[node]["color"] for node in G.nodes]
71
+
72
+ pos = nx.spring_layout(G)
73
+ x_coords = [pos[node][0] for node in G.nodes]
74
+ y_coords = [pos[node][1] for node in G.nodes]
75
+
76
+ fig = go.Figure()
77
+
78
+ fig.add_trace(
79
+ go.Scatter(
80
+ x=x_coords,
81
+ y=y_coords,
82
+ mode="markers+text",
83
+ marker=dict(size=node_sizes, color=node_colors, opacity=0.8),
84
+ text=list(G.nodes),
85
+ textposition="top center",
86
+ )
87
+ )
88
+
89
+ fig.update_layout(
90
+ title="Company Network Visualization",
91
+ xaxis=dict(showgrid=False, zeroline=False),
92
+ yaxis=dict(showgrid=False, zeroline=False),
93
+ showlegend=False,
94
+ )
95
+
96
+ # Note: All companies are in green while venture firms have different colors.
97
+ # The diameter of the company circle varies proportionate to the valuation of the corresponding company.
98
+
99
+ # Create Gradio interface
100
+ def display_network():
101
+ return fig.to_html()
102
+
103
+ gr.Interface(
104
+ fn=display_network,
105
+ inputs=[],
106
+ outputs="html",
107
+ title="Company Network Visualization",
108
+ ).launch()