lyimo commited on
Commit
f680ba4
·
verified ·
1 Parent(s): a04c319

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -67
app.py CHANGED
@@ -16,146 +16,176 @@ from fastcore.utils import * # For Upload class
16
 
17
  # --- Configuration ---
18
  # Ensure the path to your exported model is correct
19
- # When deploying to HF Spaces, this relative path should work if model.pkl is in the same directory
20
  MODEL_PATH = Path(__file__).parent / 'model.pkl'
21
- # Set device (CPU is usually the default/safest for HF free tier)
 
22
  defaults.device = torch.device('cpu')
23
 
24
  # --- Load Fastai Learner ---
 
25
  try:
26
- print(f"Loading model from: {MODEL_PATH}")
 
 
27
  learn = load_learner(MODEL_PATH)
28
  print("Model loaded successfully.")
29
  # Get class names (vocab) from the learner's dataloaders
30
  CLASS_NAMES = learn.dls.vocab
31
  print(f"Model Classes: {CLASS_NAMES}")
32
- except FileNotFoundError:
33
- print(f"Error: Model file not found at {MODEL_PATH}")
34
- print("Please make sure 'model.pkl' is in the same directory or update MODEL_PATH.")
35
- # In a deployed environment, you might want to raise the error or handle it differently
36
- # For now, we exit if the model isn't found on startup.
37
- raise SystemExit(f"Error: Model file not found at {MODEL_PATH}")
38
  except Exception as e:
39
- print(f"Error loading the model: {e}")
40
- raise SystemExit(f"Error loading the model: {e}")
 
41
 
42
  # --- FastHTML App Setup ---
43
- # FastHTML automatically finds this 'app' object when run with uvicorn app:app
44
  app = FastHTML()
45
- rt = app.route # Route decorator
46
 
47
  # --- Helper Function for Prediction ---
48
  def predict_image(img_bytes: bytes):
49
  """Takes image bytes, predicts using the fastai model."""
 
 
50
  try:
 
51
  img = PILImage.create(img_bytes)
 
52
  pred_class, pred_idx, probs = learn.predict(img)
53
- confidence = probs[pred_idx].item() # Get the probability of the predicted class
 
54
  return pred_class, confidence
55
  except Exception as e:
56
  print(f"Error during prediction: {e}")
57
  # Return a user-friendly error message and neutral confidence
58
- return f"Prediction Error: {e}", 0.0
59
 
60
  # --- Define Routes ---
61
 
62
  @rt("/")
63
  async def get(request):
64
  """Serves the main page with the upload form."""
65
- # Using Bootstrap classes for basic styling
66
  return Titled("Fastai Image Classifier",
67
  Main(cls="container mt-4",
68
  H1("Upload an Image for Classification"),
69
- # Form for uploading the image
 
70
  Form(
 
71
  Div(cls="mb-3",
72
- # Label("Choose Image", fr="fileInput", cls="form-label"), # Optional label
73
- FileInput(name="file", id="fileInput", cls="form-control", required=True), # Added required
74
  ),
75
  Button("Classify Image", type="submit", cls="btn btn-primary"), # Submit button
76
- # HTMX attributes for form submission
77
- hx_post="/predict", # POST request to /predict
78
- hx_target="#results", # Put the response into the #results div
79
- hx_swap="innerHTML", # Replace the content of #results
80
- hx_encoding="multipart/form-data", # Needed for file uploads
81
- # Add indicator for user feedback during processing
82
- hx_indicator="#loading-spinner",
83
- id="upload-form"
84
- ),
85
- # Loading indicator (hidden by default)
86
- Div(id="loading-spinner", cls="htmx-indicator spinner-border mt-3", role="status",
87
- Span("Loading...", cls="visually-hidden")
 
 
88
  ),
89
- # Div where results will be displayed
 
90
  Div(id="results", cls="mt-4")
91
- )
92
- )
93
 
94
  @rt("/predict", methods=["POST"])
95
  async def post(request, file: Upload):
96
- """Handles image upload, prediction, and returns results."""
 
97
  if not file or not file.filename:
98
- return P("No file uploaded. Please select a file.", cls="alert alert-warning")
 
99
 
100
- # Check if it's likely an image file (basic check)
101
  allowed_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'}
102
  file_ext = Path(file.filename).suffix.lower()
103
  if file_ext not in allowed_extensions:
104
- return P(f"Invalid file type: {file_ext}. Please upload an image (png, jpg, jpeg, gif, bmp, webp).", cls="alert alert-danger")
 
105
 
106
- print(f"Received file: {file.filename}, Content-Type: {file.content_type}")
107
 
108
- # Read image bytes
109
  try:
110
- img_bytes = await file.read() # Use await for async file reading
111
  if not img_bytes:
112
- return P("Uploaded file is empty.", cls="alert alert-warning")
 
113
  except Exception as e:
114
- print(f"Error reading file: {e}")
115
- return P(f"Error reading uploaded file: {e}", cls="alert alert-danger")
 
116
 
117
- # Perform prediction
118
  prediction, confidence = predict_image(img_bytes)
119
 
120
- # Encode image to base64 to display it back
 
121
  img_src = None
122
- if "Error" not in prediction: # Only try to display image if prediction didn't fail critically
123
  try:
124
  img_base64 = base64.b64encode(img_bytes).decode('utf-8')
125
- # Try to use the provided content type, default if necessary
126
  content_type = file.content_type if file.content_type and file.content_type.startswith('image/') else 'image/jpeg'
127
  img_src = f"data:{content_type};base64,{img_base64}"
128
  except Exception as e:
129
  print(f"Error encoding image to base64: {e}")
130
- # Don't display image if encoding fails, but still show prediction
131
 
132
- # Return the results as HTML fragment
133
- # Using Bootstrap alert classes for results
134
- result_cls = "alert alert-danger" if "Error" in prediction else "alert alert-success"
135
 
 
 
136
  return Div(
137
- (Img(src=img_src, alt="Uploaded Image", style="max-width: 300px; max-height: 300px; margin-top: 15px; margin-bottom: 10px; display: block;") if img_src else P("Preview not available.")),
 
 
138
  Div(cls=f"{result_cls} mt-3", role="alert",
139
  P(Strong("Prediction: "), f"{prediction}"),
140
- P(Strong("Confidence: "), f"{confidence:.4f}") if "Error" not in prediction else ""
141
- )
 
 
 
 
 
142
  )
143
 
144
- # --- Add CSS/JS ---
145
- # Add Bootstrap CSS and JS for styling and components (like the spinner)
146
- app.hdrs.append(
147
- Script(src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js", integrity="sha384-C6RzsynM9kWDrMNeT87bh95OGNyZPhcTNXj1NW7RuBCsyN/o0jlpcV8Qyq46cDfL", crossorigin="anonymous"),
148
- )
149
  app.sheets.append(
150
  Link(href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css", rel="stylesheet", integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN", crossorigin="anonymous")
151
  )
152
- # Add HTMX itself (FastHTML includes a minimal version, but sometimes the full CDN is useful)
153
- # app.hdrs.append(Script(src="https://unpkg.com/htmx.org@1.9.10/dist/htmx.min.js"))
154
-
155
 
156
- # --- Run the App (for local testing) ---
157
- # This part is mainly for running locally with `python app.py`
158
- # When deployed on Hugging Face Spaces, Uvicorn is run automatically based on the README config
159
  if __name__ == "__main__":
160
- # Use port 8000 which is often standard for these deployments
 
 
 
161
  uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
 
16
 
17
  # --- Configuration ---
18
  # Ensure the path to your exported model is correct
19
+ # When deploying to HF Spaces, this relative path works if model.pkl is in the same directory
20
  MODEL_PATH = Path(__file__).parent / 'model.pkl'
21
+ # Set device (CPU is often the default/safest for HF free tier)
22
+ # Use 'cuda' if you have a GPU and want to use it: defaults_device(use_cuda=True)
23
  defaults.device = torch.device('cpu')
24
 
25
  # --- Load Fastai Learner ---
26
+ # Load the model once when the application starts
27
  try:
28
+ print(f"Attempting to load model from: {MODEL_PATH.resolve()}") # Use resolve() for absolute path in logs
29
+ if not MODEL_PATH.is_file():
30
+ raise FileNotFoundError(f"Model file not found at calculated path: {MODEL_PATH.resolve()}")
31
  learn = load_learner(MODEL_PATH)
32
  print("Model loaded successfully.")
33
  # Get class names (vocab) from the learner's dataloaders
34
  CLASS_NAMES = learn.dls.vocab
35
  print(f"Model Classes: {CLASS_NAMES}")
36
+ except FileNotFoundError as e:
37
+ print(f"Error: {e}")
38
+ print("Please make sure 'model.pkl' is in the same directory as app.py.")
39
+ # Exit if model loading fails, as the app cannot function
40
+ raise SystemExit(f"CRITICAL ERROR: Model file not found at {MODEL_PATH}. Application cannot start.")
 
41
  except Exception as e:
42
+ print(f"CRITICAL ERROR: An unexpected error occurred loading the model: {e}")
43
+ # Exit for any other critical model loading error
44
+ raise SystemExit(f"CRITICAL ERROR: Failed to load model. Application cannot start. Error: {e}")
45
 
46
  # --- FastHTML App Setup ---
47
+ # FastHTML/Uvicorn will automatically find this 'app' object when run via 'uvicorn app:app'
48
  app = FastHTML()
49
+ rt = app.route # Shortcut for the route decorator
50
 
51
  # --- Helper Function for Prediction ---
52
  def predict_image(img_bytes: bytes):
53
  """Takes image bytes, predicts using the fastai model."""
54
+ if not img_bytes:
55
+ return "Error: Image data is empty", 0.0
56
  try:
57
+ # Create PILImage from bytes
58
  img = PILImage.create(img_bytes)
59
+ # Get prediction from the learner
60
  pred_class, pred_idx, probs = learn.predict(img)
61
+ # Get the confidence score for the predicted class
62
+ confidence = probs[pred_idx].item()
63
  return pred_class, confidence
64
  except Exception as e:
65
  print(f"Error during prediction: {e}")
66
  # Return a user-friendly error message and neutral confidence
67
+ return f"Prediction Error: Could not process image ({e})", 0.0
68
 
69
  # --- Define Routes ---
70
 
71
  @rt("/")
72
  async def get(request):
73
  """Serves the main page with the upload form."""
74
+ # Using Bootstrap classes for basic styling and layout
75
  return Titled("Fastai Image Classifier",
76
  Main(cls="container mt-4",
77
  H1("Upload an Image for Classification"),
78
+ # --- Form for uploading the image ---
79
+ # Arguments MUST be ordered: Positional arguments (content) first, then Keyword arguments (attributes)
80
  Form(
81
+ # --- Positional Arguments (Form Content) ---
82
  Div(cls="mb-3",
83
+ # File input element
84
+ FileInput(name="file", id="fileInput", cls="form-control", required=True, accept="image/*"), # Added required and accept attributes
85
  ),
86
  Button("Classify Image", type="submit", cls="btn btn-primary"), # Submit button
87
+
88
+ # --- Keyword Arguments (Form Attributes) ---
89
+ # HTMX attributes for handling the submission
90
+ hx_post="/predict", # Send POST request to /predict endpoint
91
+ hx_target="#results", # Put the response HTML into the div with id="results"
92
+ hx_swap="innerHTML", # Replace the entire content of the target div
93
+ hx_encoding="multipart/form-data", # Necessary for file uploads
94
+ hx_indicator="#loading-spinner", # Show the element with id="loading-spinner" during the request
95
+ id="upload-form" # Standard HTML id for the form
96
+ ), # End of Form component arguments
97
+ # --- Loading Indicator ---
98
+ # This div is shown by hx-indicator during the HTMX request
99
+ Div(id="loading-spinner", cls="htmx-indicator spinner-border mt-3", role="status", style="display: none;", # Initially hidden
100
+ Span("Loading...", cls="visually-hidden") # Accessibility text for the spinner
101
  ),
102
+ # --- Results Area ---
103
+ # This div is targeted by hx-target to display the prediction results
104
  Div(id="results", cls="mt-4")
105
+ ) # End of Main component
106
+ ) # End of Titled component
107
 
108
  @rt("/predict", methods=["POST"])
109
  async def post(request, file: Upload):
110
+ """Handles image upload, performs prediction, and returns results as an HTML fragment."""
111
+ # --- Input Validation ---
112
  if not file or not file.filename:
113
+ # Return an error message if no file is received
114
+ return Div(P("No file uploaded. Please select an image file.", cls="alert alert-warning mt-3"), id="results") # Ensure id matches target
115
 
116
+ # Basic check for allowed image file extensions
117
  allowed_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'}
118
  file_ext = Path(file.filename).suffix.lower()
119
  if file_ext not in allowed_extensions:
120
+ # Return an error message for invalid file types
121
+ return Div(P(f"Invalid file type: '{file_ext}'. Please upload an image ({', '.join(allowed_extensions)}).", cls="alert alert-danger mt-3"), id="results") # Ensure id matches target
122
 
123
+ print(f"Received file: {file.filename}, Content-Type: {file.content_type}, Size: {file.size}")
124
 
125
+ # --- Read Image Data ---
126
  try:
127
+ img_bytes = await file.read() # Read the file content asynchronously
128
  if not img_bytes:
129
+ # Handle empty file upload
130
+ return Div(P("Uploaded file appears to be empty.", cls="alert alert-warning mt-3"), id="results") # Ensure id matches target
131
  except Exception as e:
132
+ print(f"Error reading uploaded file: {e}")
133
+ # Return an error if reading fails
134
+ return Div(P(f"Error reading uploaded file: {e}", cls="alert alert-danger mt-3"), id="results") # Ensure id matches target
135
 
136
+ # --- Perform Prediction ---
137
  prediction, confidence = predict_image(img_bytes)
138
 
139
+ # --- Prepare Response ---
140
+ # Encode image to base64 to display a preview, only if prediction was okay
141
  img_src = None
142
+ if "Error" not in str(prediction): # Check if the prediction result indicates an error
143
  try:
144
  img_base64 = base64.b64encode(img_bytes).decode('utf-8')
145
+ # Try to use the provided content type, default if necessary or invalid
146
  content_type = file.content_type if file.content_type and file.content_type.startswith('image/') else 'image/jpeg'
147
  img_src = f"data:{content_type};base64,{img_base64}"
148
  except Exception as e:
149
  print(f"Error encoding image to base64: {e}")
150
+ # Log error, but proceed without image preview
151
 
152
+ # Determine result styling based on success or failure
153
+ result_cls = "alert alert-danger" if "Error" in str(prediction) else "alert alert-success"
 
154
 
155
+ # --- Return HTML Fragment ---
156
+ # This HTML will replace the content of the #results div
157
  return Div(
158
+ # Display image preview if available
159
+ (Img(src=img_src, alt="Uploaded Image Preview", style="max-width: 300px; max-height: 300px; margin-top: 15px; margin-bottom: 10px; display: block; border: 1px solid #ddd;") if img_src else P("Preview not available.")),
160
+ # Display prediction results or error message
161
  Div(cls=f"{result_cls} mt-3", role="alert",
162
  P(Strong("Prediction: "), f"{prediction}"),
163
+ # Only show confidence if prediction was successful
164
+ (P(Strong("Confidence: "), f"{confidence:.4f}") if "Error" not in str(prediction) else None)
165
+ ),
166
+ # Important: The root element returned should match the hx-target for replacement,
167
+ # or be structured such that the target is updated as intended. Here, we replace the entire #results div content.
168
+ id="results", # Adding id here ensures the target div itself is replaced if needed, though innerHTML swap is default
169
+ hx_swap_oob="true" # Example if you wanted to update multiple targets, not needed here for innerHTML swap.
170
  )
171
 
172
+
173
+ # --- Add CSS/JS Headers ---
174
+ # Include Bootstrap CSS for styling and JS for potential component interactions (like dropdowns, modals, etc., though not used here)
175
+ # FastHTML automatically includes HTMX
 
176
  app.sheets.append(
177
  Link(href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css", rel="stylesheet", integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN", crossorigin="anonymous")
178
  )
179
+ app.hdrs.append(
180
+ Script(src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/js/bootstrap.bundle.min.js", integrity="sha384-C6RzsynM9kWDrMNeT87bh95OGNyZPhcTNXj1NW7RuBCsyN/o0jlpcV8Qyq46cDfL", crossorigin="anonymous"),
181
+ )
182
 
183
+ # --- Run the App (for local development) ---
184
+ # This block is executed when you run `python app.py` directly.
185
+ # Hugging Face Spaces will use its own mechanism to run the 'app' object via an ASGI server like Uvicorn.
186
  if __name__ == "__main__":
187
+ print("Starting Uvicorn server for local development...")
188
+ # Use host="0.0.0.0" to make it accessible on your network
189
+ # Port 8000 is a common choice for web development
190
+ # reload=True automatically restarts the server when code changes (useful for development)
191
  uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)