angusfung committed on
Commit
6c3a56c
·
verified ·
1 Parent(s): 01783ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -6
app.py CHANGED
@@ -1,3 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import torch
@@ -49,6 +60,18 @@ device = None
49
 
50
  @asynccontextmanager
51
  async def lifespan(app: FastAPI):
 
 
 
 
 
 
 
 
 
 
 
 
52
  # Load resources on startup
53
  global model, explainer, processor, device
54
 
@@ -125,6 +148,12 @@ app.add_middleware(
125
 
126
  @app.get("/")
127
  async def root():
 
 
 
 
 
 
128
  return {
129
  "message": "Kickstarter Success Prediction API",
130
  "description": "Send a POST request to /predict with campaign data to get a prediction"
@@ -132,6 +161,24 @@ async def root():
132
 
133
  @app.post("/predict")
134
  async def predict(request: Request):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  try:
136
  # Parse the incoming JSON data
137
  logger.info("Received prediction request")
@@ -193,28 +240,114 @@ async def predict(request: Request):
193
  raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
194
 
195
  def preprocess_raw_data(campaign_data):
196
- """Preprocess raw data using CampaignProcessor"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  try:
198
  # Process the single campaign
199
  logger.info("Processing campaign with CampaignProcessor...")
200
- processed_data = processor.process_campaign(campaign_data, idx=0)
201
 
202
- # Preserve existing numerical values from input if present
203
- for field in NUMERICAL_FIELDS:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  if field in campaign_data:
205
  processed_data[field] = campaign_data[field]
206
  logger.info(f"Using provided value for {field}: {campaign_data[field]}")
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  return processed_data
209
 
210
  except Exception as e:
211
  logger.error(f"Error preprocessing raw data: {str(e)}", exc_info=True)
212
  raise Exception(f"Error preprocessing raw data: {str(e)}")
213
 
214
- # Debugging endpoint to check the environment and loaded resources
215
  @app.get("/debug")
216
  async def debug():
217
- """Endpoint for checking the status of the API and its components"""
 
 
 
 
 
 
 
 
218
  global model, explainer, processor, device
219
 
220
  # Check internet connectivity
 
1
+ """
2
+ Kickstarter Success Prediction API
3
+
4
+ This module serves as the main FastAPI application for the Kickstarter Success Prediction service.
5
+ It provides endpoints for predicting the success probability of Kickstarter campaigns and
6
+ includes the Longformer embedding in the response for further analysis.
7
+
8
+ Author: Angus Fung
9
+ Date: April 2025
10
+ """
11
+
12
  import os
13
  import json
14
  import torch
 
60
 
61
  @asynccontextmanager
62
  async def lifespan(app: FastAPI):
63
+ """
64
+ Lifecycle manager for the FastAPI application.
65
+
66
+ This function handles the startup and shutdown of the application,
67
+ managing resources like model loading and caching directories.
68
+
69
+ Args:
70
+ app: The FastAPI application instance
71
+
72
+ Yields:
73
+ None: Control is yielded back to the application while it's running
74
+ """
75
  # Load resources on startup
76
  global model, explainer, processor, device
77
 
 
148
 
149
  @app.get("/")
150
  async def root():
151
+ """
152
+ Root endpoint providing API information.
153
+
154
+ Returns:
155
+ dict: Basic API information and usage instructions
156
+ """
157
  return {
158
  "message": "Kickstarter Success Prediction API",
159
  "description": "Send a POST request to /predict with campaign data to get a prediction"
 
161
 
162
  @app.post("/predict")
163
  async def predict(request: Request):
164
+ """
165
+ Prediction endpoint for Kickstarter campaign success.
166
+
167
+ This endpoint processes campaign data and returns:
168
+ - Success probability
169
+ - Predicted outcome (Success/Failure)
170
+ - SHAP values for feature importance explanation
171
+ - Longformer embedding of the campaign description
172
+
173
+ Args:
174
+ request: FastAPI request object containing campaign data as JSON
175
+
176
+ Returns:
177
+ JSONResponse: Prediction results and explanations
178
+
179
+ Raises:
180
+ HTTPException: If an error occurs during prediction
181
+ """
182
  try:
183
  # Parse the incoming JSON data
184
  logger.info("Received prediction request")
 
240
  raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
241
 
242
  def preprocess_raw_data(campaign_data):
243
+ """
244
+ Preprocess raw campaign data using CampaignProcessor.
245
+
246
+ This function transforms raw text and numerical campaign data into
247
+ the format required by the prediction model, including:
248
+ - Text embeddings generation for description, blurb, and risks
249
+ - Logarithmic transformation of monetary values (funding goals, pledged amounts)
250
+ - Country name standardization (conversion to ISO alpha-2 codes)
251
+ - Category and country encoding
252
+ - Extraction and normalization of numerical features
253
+
254
+ Args:
255
+ campaign_data (dict): Raw campaign data with text and numerical features
256
+
257
+ Returns:
258
+ dict: Processed data with embeddings and normalized numerical features
259
+
260
+ Raises:
261
+ Exception: If preprocessing fails
262
+ """
263
  try:
264
  # Process the single campaign
265
  logger.info("Processing campaign with CampaignProcessor...")
 
266
 
267
+ # Log country conversion if present
268
+ if 'raw_country' in campaign_data:
269
+ country_name = campaign_data.get('raw_country', '')
270
+ if country_name:
271
+ logger.info(f"Found country in input data: '{country_name}' (will be converted to ISO alpha-2 code)")
272
+
273
+ # Map field names to the expected structure for the processor
274
+ # Make a deep copy to avoid modifying the original
275
+ import copy
276
+ prepared_data = copy.deepcopy(campaign_data)
277
+
278
+ # Log input values for debugging
279
+ logger.info(f"Input previous_projects_count: {prepared_data.get('previous_projects_count', 'N/A')}")
280
+ logger.info(f"Input previous_success_rate: {prepared_data.get('previous_success_rate', 'N/A')}")
281
+ logger.info(f"Input previous_pledged: {prepared_data.get('previous_pledged', 'N/A')}")
282
+ logger.info(f"Input previous_funding_goal: {prepared_data.get('previous_funding_goal', 'N/A')}")
283
+
284
+ # Special handling for success rate calculation
285
+ if 'previous_success_rate' in campaign_data and 'previous_projects_count' in campaign_data:
286
+ success_rate = float(campaign_data['previous_success_rate'])
287
+ projects_count = int(campaign_data['previous_projects_count'])
288
+ # Calculate successful projects from rate and count
289
+ if projects_count > 0:
290
+ prepared_data['previous_successful_projects'] = round(success_rate * projects_count)
291
+ logger.info(f"Calculated previous_successful_projects: {prepared_data['previous_successful_projects']} " +
292
+ f"from success rate: {success_rate} and count: {projects_count}")
293
+
294
+ # Now process the prepared data
295
+ processed_data = processor.process_campaign(prepared_data, idx=0)
296
+
297
+ # SELECTIVE OVERRIDE: Only override non-transformed numeric fields
298
+ # Fields that should NOT undergo logarithmic transformation
299
+ non_transformed_fields = [
300
+ 'description_length', 'image_count', 'video_count',
301
+ 'campaign_duration', 'previous_projects_count', 'previous_success_rate'
302
+ ]
303
+
304
+ # Fields that SHOULD undergo logarithmic transformation
305
+ transformed_fields = [
306
+ 'funding_goal', 'previous_funding_goal', 'previous_pledged'
307
+ ]
308
+
309
+ # Override only the non-transformed fields if they exist in input
310
+ for field in non_transformed_fields:
311
  if field in campaign_data:
312
  processed_data[field] = campaign_data[field]
313
  logger.info(f"Using provided value for {field}: {campaign_data[field]}")
314
 
315
+ # For transformed fields, check if the user explicitly wants to bypass transformation
316
+ for field in transformed_fields:
317
+ if field in campaign_data and campaign_data.get('bypass_transformation', False):
318
+ processed_data[field] = campaign_data[field]
319
+ logger.warning(
320
+ f"Bypassing logarithmic transformation for {field} as requested. "
321
+ "This may affect model performance."
322
+ )
323
+ elif field in campaign_data:
324
+ # Log that we're keeping the transformed value
325
+ logger.info(f"Using logarithmically transformed {field} value for better model performance.")
326
+
327
+ # Verify that the previous metrics are set correctly
328
+ logger.info(f"Final previous_projects_count: {processed_data.get('previous_projects_count', 'N/A')}")
329
+ logger.info(f"Final previous_success_rate: {processed_data.get('previous_success_rate', 'N/A')}")
330
+ logger.info(f"Final previous_pledged: {processed_data.get('previous_pledged', 'N/A')}")
331
+ logger.info(f"Final previous_funding_goal: {processed_data.get('previous_funding_goal', 'N/A')}")
332
+
333
+ logger.info("Preprocessing complete with numerical transformations applied")
334
  return processed_data
335
 
336
  except Exception as e:
337
  logger.error(f"Error preprocessing raw data: {str(e)}", exc_info=True)
338
  raise Exception(f"Error preprocessing raw data: {str(e)}")
339
 
 
340
  @app.get("/debug")
341
  async def debug():
342
+ """
343
+ Debug endpoint for checking API status and component health.
344
+
345
+ This endpoint provides diagnostic information about the API's status,
346
+ model loading, connectivity, disk space, and other components.
347
+
348
+ Returns:
349
+ JSONResponse: Comprehensive diagnostic information
350
+ """
351
  global model, explainer, processor, device
352
 
353
  # Check internet connectivity