LeonceNsh committed on
Commit
2f36052
·
verified ·
1 Parent(s): 75a7b9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py CHANGED
@@ -1,3 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Gradio app function to update CheckboxGroup and filtered data
2
  def app(selected_country, selected_industry):
3
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)
 
1
+ import pandas as pd
2
+ import networkx as nx
3
+ import plotly.graph_objects as go
4
+ import gradio as gr
5
+ import logging
6
+
7
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load and preprocess the dataset
file_path = "cbinsights_data.csv"  # Replace with your actual file path

try:
    # skiprows=1 discards the first physical line of the file before the
    # header is read. NOTE(review): presumably a title row sits above the
    # real header in the export -- confirm against the CSV.
    data = pd.read_csv(file_path, skiprows=1)
except FileNotFoundError:
    logger.error(f"File not found: {file_path}")
    raise
except Exception as e:
    logger.error(f"Error loading CSV file: {e}")
    raise
else:
    logger.info("CSV file loaded successfully.")

# Standardize column names: strip whitespace, convert to lowercase and turn
# inner whitespace into underscores, so that a header such as
# "Select Investors" matches the "select_investors" key required below.
data.columns = (
    data.columns.str.strip().str.lower().str.replace(r"\s+", "_", regex=True)
)
logger.info(f"Standardized Column Names: {data.columns.tolist()}")

# Identify the valuation column dynamically (names are already lowercase).
valuation_columns = [col for col in data.columns if 'valuation' in col]
if not valuation_columns:
    logger.error("No column containing 'Valuation' found in the dataset.")
    raise ValueError("Data Error: Unable to find the valuation column. Please check your CSV file.")
elif len(valuation_columns) > 1:
    logger.error("Multiple columns containing 'Valuation' found in the dataset.")
    raise ValueError("Data Error: Multiple valuation columns detected. Please ensure only one valuation column exists.")
else:
    valuation_column = valuation_columns[0]
    logger.info(f"Using valuation column: {valuation_column}")

# Clean and prepare data: drop "$" and "," from the valuation text, then
# convert to numbers. The pattern is a raw string because "\$" is not a
# valid escape sequence in a normal string literal.
data["valuation_billions"] = data[valuation_column].replace({r'\$': '', ',': ''}, regex=True)
# errors='coerce' turns anything that still is not a number into NaN.
data["valuation_billions"] = pd.to_numeric(data["valuation_billions"], errors='coerce')
logger.info("Valuation data cleaned and converted to numeric.")

# Strip whitespace from all string columns
data = data.apply(lambda col: col.str.strip() if col.dtype == "object" else col)
logger.info("Whitespace stripped from all string columns.")

# Rename columns for consistency; every key listed here is required.
expected_columns = {
    "company": "Company",
    "valuation_billions": "Valuation_Billions",
    "date_joined": "Date_Joined",
    "country": "Country",
    "city": "City",
    "industry": "Industry",
    "select_investors": "Select_Investors",
}

missing_columns = set(expected_columns.keys()) - set(data.columns)
if missing_columns:
    logger.error(f"Missing columns in the dataset: {missing_columns}")
    raise ValueError(f"Data Error: Missing columns {missing_columns} in the dataset.")

data = data.rename(columns=expected_columns)
logger.info("Columns renamed for consistency.")
67
+
68
+ # Parse the "Select_Investors" column to map investors to companies
69
def build_investor_company_mapping(df):
    """Map each investor name to the companies it is listed against.

    Args:
        df: DataFrame with a "Company" column and a "Select_Investors"
            column whose cells hold comma-separated investor names.

    Returns:
        dict mapping each whitespace-stripped investor name to a list of
        the "Company" values of the rows naming that investor, in row
        order. Rows with a null investor cell and empty name fragments
        (e.g. from "A,,B") are skipped.
    """
    mapping = {}
    # Walk the two columns in lockstep instead of DataFrame.iterrows(),
    # which builds a full Series object for every row.
    for company, investors in zip(df["Company"], df["Select_Investors"]):
        if pd.isnull(investors):
            continue
        # str() keeps a non-null, non-string cell (e.g. a number) from
        # raising AttributeError on .split().
        for investor in str(investors).split(","):
            investor = investor.strip()
            if investor:  # Ensure investor is not an empty string
                mapping.setdefault(investor, []).append(company)
    return mapping
80
+
81
# Build the investor -> companies lookup once, over the full cleaned dataset.
investor_company_mapping = build_investor_company_mapping(df=data)
logger.info("Investor to company mapping created.")
83
+
84
+ # Function to filter investors based on selected country and industry
85
def filter_investors_by_country_and_industry(selected_country, selected_industry, min_total_valuation=20):
    """Filter the dataset and list the investors above a valuation threshold.

    Args:
        selected_country: Value to match against the "Country" column, or
            "All" to skip the country filter.
        selected_industry: Value to match against the "Industry" column, or
            "All" to skip the industry filter.
        min_total_valuation: Minimum summed "Valuation_Billions" across an
            investor's companies (within the filtered rows) for that
            investor to be returned. Defaults to 20, the value this
            function previously hard-coded.

    Returns:
        A tuple ``(investors, filtered_data)``: the list of qualifying
        investor names, and the filtered copy of the module-level ``data``
        frame they were computed from.
    """
    filtered_data = data.copy()
    logger.info(f"Filtering data for Country: {selected_country}, Industry: {selected_industry}")

    if selected_country != "All":
        filtered_data = filtered_data[filtered_data["Country"] == selected_country]
        logger.info(f"Data filtered by country: {selected_country}. Remaining records: {len(filtered_data)}")
    if selected_industry != "All":
        filtered_data = filtered_data[filtered_data["Industry"] == selected_industry]
        logger.info(f"Data filtered by industry: {selected_industry}. Remaining records: {len(filtered_data)}")

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)

    # Calculate total valuation per investor. The two columns are looked up
    # once here rather than on every loop iteration.
    company_names = filtered_data["Company"]
    valuations = filtered_data["Valuation_Billions"]
    investor_valuations = {}
    for investor, companies in investor_company_mapping_filtered.items():
        total_valuation = valuations[company_names.isin(companies)].sum()
        if total_valuation >= min_total_valuation:
            investor_valuations[investor] = total_valuation

    logger.info(f"Filtered investors with total valuation >= {min_total_valuation}B: {len(investor_valuations)}")

    return list(investor_valuations.keys()), filtered_data
108
+
109
+ # Function to generate the Plotly graph
110
def generate_graph(selected_investors, filtered_data):
    """Build a Plotly figure of the investor-company network.

    Args:
        selected_investors: Iterable of investor names to draw. Names that
            do not occur in ``filtered_data`` are ignored.
        filtered_data: DataFrame with "Company", "Select_Investors" and
            "Valuation_Billions" columns.

    Returns:
        A ``plotly.graph_objects.Figure``. It is empty when no investors are
        selected; otherwise it holds one line trace for the edges and one
        marker trace for the nodes (investors orange and fixed-size,
        companies green and sized by valuation).
    """
    if not selected_investors:
        logger.warning("No investors selected. Returning empty figure.")
        return go.Figure()

    investor_company_mapping_filtered = build_investor_company_mapping(filtered_data)
    filtered_mapping = {
        inv: investor_company_mapping_filtered[inv]
        for inv in selected_investors
        if inv in investor_company_mapping_filtered
    }

    logger.info(f"Generating graph for {len(filtered_mapping)} investors.")

    # Build the graph: one edge per (investor, company) pair.
    G = nx.Graph()
    for investor, companies in filtered_mapping.items():
        for company in companies:
            G.add_edge(investor, company)

    # Generate positions using spring layout; the fixed seed keeps the
    # layout reproducible between calls.
    pos = nx.spring_layout(G, k=0.2, seed=42)

    # Prepare Plotly traces. A None after each coordinate pair breaks the
    # line so every edge is drawn as its own segment.
    edge_x = []
    edge_y = []
    for source, target in G.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Loop-invariant valuation data, computed once instead of per node.
    company_valuation = filtered_data.groupby("Company")["Valuation_Billions"].sum()
    max_valuation = filtered_data["Valuation_Billions"].max()
    # NaN is truthy, so a plain truthiness test would let an all-NaN column
    # through and produce NaN marker sizes; check for it explicitly.
    can_scale = pd.notna(max_valuation) and max_valuation != 0

    node_x = []
    node_y = []
    node_text = []
    node_size = []
    node_color = []
    customdata = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        if node in filtered_mapping:
            node_text.append(f"Investor: {node}")
            node_size.append(20)  # Investors have larger size
            node_color.append('orange')
            customdata.append(None)  # Investors do not have a single valuation
        else:
            valuation = company_valuation.get(node, 0.0)
            node_text.append(f"Company: {node}<br>Valuation: ${valuation}B")
            # Company markers range from 10 up to 40, relative to the
            # largest valuation in the filtered data.
            node_size.append(10 + (valuation / max_valuation) * 30 if can_scale else 10)
            node_color.append('green')
            customdata.append(f"${valuation}B")

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        text=node_text,
        customdata=customdata,
        marker=dict(
            showscale=False,
            colorscale='YlGnBu',
            color=node_color,
            size=node_size,
            # Final outline style; set here directly instead of setting a
            # width and then overriding it with update_traces().
            line=dict(width=0.5, color='white')
        )
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            # title.font replaces the deprecated layout.titlefont attribute,
            # which newer Plotly releases no longer accept.
            title=dict(text='Venture Network Visualization', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="",
                showarrow=False,
                xref="paper", yref="paper")],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
        )
    )

    logger.info("Plotly graph generated successfully.")

    return fig
202
+
203
  # Gradio app function to update CheckboxGroup and filtered data
204
  def app(selected_country, selected_industry):
205
  investor_list, filtered_data = filter_investors_by_country_and_industry(selected_country, selected_industry)