lyimo commited on
Commit
7ae5b25
·
verified ·
1 Parent(s): 3d2023e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -26
app.py CHANGED
@@ -1,30 +1,161 @@
1
- import gradio as gr
 
 
 
 
 
 
 
2
  from fastai.vision.all import *
3
- import skimage
4
-
5
-
6
- learn = load_learner('model.pkl')
7
- #from huggingface_hub import from_pretrained_fastai
8
- #learn = from_pretrained_fastai("devdatanalytics/commonbean")
9
-
10
- labels = learn.dls.vocab
11
- def predict(img):
12
- img = PILImage.create(img)
13
- pred,pred_idx,probs = learn.predict(img)
14
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
15
-
16
- #title = "Common beans diseases classfier"
17
- #description = "An app for Common beans diseases Classisfication"
18
- #article="<p style='text-align: center'>The app identifies and classifies common beans diseases: Anthracnose and Bean rust.</p>"
19
- # Create the Gradio interface
20
- interface = gr.Interface(
21
- fn=predict,
22
- inputs=gr.Image(),
23
- outputs=gr.Label(num_top_classes=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  )
 
 
25
 
26
- # Enable the queue to handle POST requests
27
- interface.queue(api_open=True)
28
 
29
- # Launch the interface
30
- interface.launch()
 
 
 
 
 
1
+ # app.py
2
+
3
+ import uvicorn
4
+ import base64
5
+ import io
6
+ from pathlib import Path
7
+
8
+ # Fastai imports
9
  from fastai.vision.all import *
10
+
11
+ # FastHTML imports
12
+ from fasthtml.common import * # Imports common HTML tags, App, etc.
13
+ from fasthtml.core import FastHTML # Base class if needed, but App is usually sufficient
14
+ from fasthtml.components import FileInput # Specific component for file input
15
+ from fastcore.utils import * # For Upload class
16
+
17
+ # --- Configuration ---
18
+ # Ensure the path to your exported model is correct
19
+ # When deploying to HF Spaces, this relative path should work if model.pkl is in the same directory
20
+ MODEL_PATH = Path(__file__).parent / 'model.pkl'
21
+ # Set device (CPU is usually the default/safest for HF free tier)
22
+ defaults.device = torch.device('cpu')
23
+
24
+ # --- Load Fastai Learner ---
25
+ try:
26
+ print(f"Loading model from: {MODEL_PATH}")
27
+ learn = load_learner(MODEL_PATH)
28
+ print("Model loaded successfully.")
29
+ # Get class names (vocab) from the learner's dataloaders
30
+ CLASS_NAMES = learn.dls.vocab
31
+ print(f"Model Classes: {CLASS_NAMES}")
32
+ except FileNotFoundError:
33
+ print(f"Error: Model file not found at {MODEL_PATH}")
34
+ print("Please make sure 'model.pkl' is in the same directory or update MODEL_PATH.")
35
+ # In a deployed environment, you might want to raise the error or handle it differently
36
+ # For now, we exit if the model isn't found on startup.
37
+ raise SystemExit(f"Error: Model file not found at {MODEL_PATH}")
38
+ except Exception as e:
39
+ print(f"Error loading the model: {e}")
40
+ raise SystemExit(f"Error loading the model: {e}")
41
+
42
+ # --- FastHTML App Setup ---
43
+ # FastHTML automatically finds this 'app' object when run with uvicorn app:app
44
+ app = FastHTML()
45
+ rt = app.route # Route decorator
46
+
47
+ # --- Helper Function for Prediction ---
48
+ def predict_image(img_bytes: bytes):
49
+ """Takes image bytes, predicts using the fastai model."""
50
+ try:
51
+ img = PILImage.create(img_bytes)
52
+ pred_class, pred_idx, probs = learn.predict(img)
53
+ confidence = probs[pred_idx].item() # Get the probability of the predicted class
54
+ return pred_class, confidence
55
+ except Exception as e:
56
+ print(f"Error during prediction: {e}")
57
+ # Return a user-friendly error message and neutral confidence
58
+ return f"Prediction Error: {e}", 0.0
59
+
60
+ # --- Define Routes ---
61
+
62
+ @rt("/")
63
+ async def get(request):
64
+ """Serves the main page with the upload form."""
65
+ # Using Bootstrap classes for basic styling
66
+ return Titled("Fastai Image Classifier",
67
+ Main(cls="container mt-4",
68
+ H1("Upload an Image for Classification"),
69
+ # Form for uploading the image
70
+ Form(
71
+ Div(cls="mb-3",
72
+ # Label("Choose Image", fr="fileInput", cls="form-label"), # Optional label
73
+ FileInput(name="file", id="fileInput", cls="form-control", required=True), # Added required
74
+ ),
75
+ Button("Classify Image", type="submit", cls="btn btn-primary"), # Submit button
76
+ # HTMX attributes for form submission
77
+ hx_post="/predict", # POST request to /predict
78
+ hx_target="#results", # Put the response into the #results div
79
+ hx_swap="innerHTML", # Replace the content of #results
80
+ hx_encoding="multipart/form-data", # Needed for file uploads
81
+ # Add indicator for user feedback during processing
82
+ hx_indicator="#loading-spinner",
83
+ id="upload-form"
84
+ ),
85
+ # Loading indicator (hidden by default)
86
+ Div(id="loading-spinner", cls="htmx-indicator spinner-border mt-3", role="status",
87
+ Span("Loading...", cls="visually-hidden")
88
+ ),
89
+ # Div where results will be displayed
90
+ Div(id="results", cls="mt-4")
91
+ )
92
+ )
93
+
94
+ @rt("/predict", methods=["POST"])
95
+ async def post(request, file: Upload):
96
+ """Handles image upload, prediction, and returns results."""
97
+ if not file or not file.filename:
98
+ return P("No file uploaded. Please select a file.", cls="alert alert-warning")
99
+
100
+ # Check if it's likely an image file (basic check)
101
+ allowed_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'}
102
+ file_ext = Path(file.filename).suffix.lower()
103
+ if file_ext not in allowed_extensions:
104
+ return P(f"Invalid file type: {file_ext}. Please upload an image (png, jpg, jpeg, gif, bmp, webp).", cls="alert alert-danger")
105
+
106
+ print(f"Received file: {file.filename}, Content-Type: {file.content_type}")
107
+
108
+ # Read image bytes
109
+ try:
110
+ img_bytes = await file.read() # Use await for async file reading
111
+ if not img_bytes:
112
+ return P("Uploaded file is empty.", cls="alert alert-warning")
113
+ except Exception as e:
114
+ print(f"Error reading file: {e}")
115
+ return P(f"Error reading uploaded file: {e}", cls="alert alert-danger")
116
+
117
+ # Perform prediction
118
+ prediction, confidence = predict_image(img_bytes)
119
+
120
+ # Encode image to base64 to display it back
121
+ img_src = None
122
+ if "Error" not in prediction: # Only try to display image if prediction didn't fail critically
123
+ try:
124
+ img_base64 = base64.b64encode(img_bytes).decode('utf-8')
125
+ # Try to use the provided content type, default if necessary
126
+ content_type = file.content_type if file.content_type and file.content_type.startswith('image/') else 'image/jpeg'
127
+ img_src = f"data:{content_type};base64,{img_base64}"
128
+ except Exception as e:
129
+ print(f"Error encoding image to base64: {e}")
130
+ # Don't display image if encoding fails, but still show prediction
131
+
132
+ # Return the results as HTML fragment
133
+ # Using Bootstrap alert classes for results
134
+ result_cls = "alert alert-danger" if "Error" in prediction else "alert alert-success"
135
+
136
+ return Div(
137
+ (Img(src=img_src, alt="Uploaded Image", style="max-width: 300px; max-height: 300px; margin-top: 15px; margin-bottom: 10px; display: block;") if img_src else P("Preview not available.")),
138
+ Div(cls=f"{result_cls} mt-3", role="alert",
139
+ P(Strong("Prediction: "), f"{prediction}"),
140
+ P(Strong("Confidence: "), f"{confidence:.4f}") if "Error" not in prediction else ""
141
+ )
142
+ )
143
+
144
+ # --- Add CSS/JS ---
145
+ # Add Bootstrap CSS and JS for styling and components (like the spinner)
146
+ app.hdrs.append(
147
+ Script(src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js", integrity="sha384-C6RzsynM9kWDrMNeT87bh95OGNyZPhcTNXj1NW7RuBCsyN/o0jlpcV8Qyq46cDfL", crossorigin="anonymous"),
148
+ )
149
+ app.sheets.append(
150
+ Link(href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css", rel="stylesheet", integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN", crossorigin="anonymous")
151
  )
152
+ # Add HTMX itself (FastHTML includes a minimal version, but sometimes the full CDN is useful)
153
+ # app.hdrs.append(Script(src="https://unpkg.com/[email protected]/dist/htmx.min.js"))
154
 
 
 
155
 
156
+ # --- Run the App (for local testing) ---
157
+ # This part is mainly for running locally with `python app.py`
158
+ # When deployed on Hugging Face Spaces, Uvicorn is run automatically based on the README config
159
+ if __name__ == "__main__":
160
+ # Use port 8000 which is often standard for these deployments
161
+ uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)