AItool commited on
Commit
9d4f900
·
verified ·
1 Parent(s): 222d735

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -28
app.py CHANGED
@@ -1,37 +1,36 @@
1
- __all__ = ['learn', 'classify_image', 'categories', 'classifier', 'virtual','image', 'label', 'examples', 'intf']
 
 
2
 
3
- # Cell
4
  from fastai.vision.all import *
5
  import gradio as gr
6
- import timm
7
- import pickle
8
- import torch
9
 
10
- # with open('./model.pkl', 'rb') as f:
11
- # model = pickle.load(f)
12
 
13
- with open('model.pkl', 'rb') as f:
14
- model = pickle.load(f)
15
-
16
- def is_real(x): return x[0].isupper()
17
-
18
- #|export
19
- learn = load_learner('model.pkl')
20
-
21
- #|export
22
- categories =('Virtual Staging','Real')
23
 
 
24
  def classify_image(img):
25
- pred,idx,probs = learn.predict(img)
26
- return dict(zip(categories,map(float,probs)))
 
27
 
28
- #*** We have to cast to float above because KAGGLE does not return number on the answer it returns tensors, and Gradio does not deal with numpy so we have to cast to float
29
-
30
- #|export
31
- image = gr.inputs.Image(shape=(192,192))
32
  label = gr.outputs.Label()
33
- examples = ['virtual.jpg','real.jpg']
34
-
35
- # Cell
36
- intf = gr.Interface(fn=classify_image, inputs=image, outputs=label, examples=examples,share=True)
37
- intf.launch()
 
 
 
 
 
 
 
 
 
1
+ __all__ = [
2
+ 'learn', 'classify_image', 'categories', 'image', 'label', 'examples', 'intf'
3
+ ]
4
 
 
5
  from fastai.vision.all import *
6
  import gradio as gr
7
+ import timm # keep only if your model actually needs timm
 
 
8
 
9
+ # ✅ Load the fastai model properly no manual pickle.load()
10
+ learn = load_learner('model.pkl', cpu=True)
11
 
12
+ # Define your categories exactly as trained
13
+ categories = ('Virtual Staging', 'Real')
 
 
 
 
 
 
 
 
14
 
15
+ # Prediction function for Gradio
16
  def classify_image(img):
17
+ pred, idx, probs = learn.predict(img)
18
+ # Cast to float so Gradio handles them cleanly
19
+ return dict(zip(categories, map(float, probs)))
20
 
21
+ # Gradio UI components
22
+ image = gr.inputs.Image(shape=(192, 192))
 
 
23
  label = gr.outputs.Label()
24
+ examples = ['virtual.jpg', 'real.jpg'] # sample files in your Space
25
+
26
+ # Create and launch interface
27
+ intf = gr.Interface(
28
+ fn=classify_image,
29
+ inputs=image,
30
+ outputs=label,
31
+ examples=examples,
32
+ share=True
33
+ )
34
+
35
+ if __name__ == "__main__":
36
+ intf.launch()