nragrawal commited on
Commit
33566fb
·
1 Parent(s): 5bf67e1

Update username in app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -22
app.py CHANGED
@@ -3,15 +3,22 @@ from transformers import AutoModelForImageClassification
3
  import torch
4
  import torchvision.transforms as transforms
5
  from PIL import Image
 
 
6
 
7
  # Load model from Hub instead of local file
8
  def load_model():
9
- model = AutoModelForImageClassification.from_pretrained(
10
- "nragrawal/resnet-imagenet",
11
- trust_remote_code=True
12
- )
13
- model.eval()
14
- return model
 
 
 
 
 
15
 
16
  # Preprocessing
17
  transform = transforms.Compose([
@@ -24,28 +31,39 @@ transform = transforms.Compose([
24
 
25
  # Inference function
26
  def predict(image):
27
- model = load_model()
28
-
29
- # Preprocess image
30
- img = Image.fromarray(image)
31
- img = transform(img).unsqueeze(0)
32
-
33
- # Inference
34
- with torch.no_grad():
35
- output = model(img)
36
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
37
 
38
- # Get top 5 predictions
39
- top5_prob, top5_catid = torch.topk(probabilities, 5)
40
- return {f"Class {i}": float(prob) for i, prob in zip(top5_catid, top5_prob)}
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # Create Gradio interface
43
  iface = gr.Interface(
44
  fn=predict,
45
  inputs=gr.Image(),
46
  outputs=gr.Label(num_top_classes=5),
47
  title="ResNet Image Classification",
48
- description="Upload an image to classify it using ResNet"
 
49
  )
50
 
51
- iface.launch()
 
 
 
 
 
 
3
  import torch
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
+ import traceback
7
+ import sys
8
 
9
  # Load model from Hub instead of local file
10
  def load_model():
11
+ try:
12
+ model = AutoModelForImageClassification.from_pretrained(
13
+ "nragrawal/resnet-imagenet",
14
+ trust_remote_code=True
15
+ )
16
+ model.eval()
17
+ return model
18
+ except Exception as e:
19
+ print(f"Error loading model: {str(e)}")
20
+ print(traceback.format_exc())
21
+ raise e
22
 
23
  # Preprocessing
24
  transform = transforms.Compose([
 
31
 
32
  # Inference function
33
  def predict(image):
34
+ try:
35
+ model = load_model()
 
 
 
 
 
 
 
 
36
 
37
+ # Preprocess image
38
+ img = Image.fromarray(image)
39
+ img = transform(img).unsqueeze(0)
40
+
41
+ # Inference
42
+ with torch.no_grad():
43
+ output = model(img)
44
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
45
+
46
+ # Get top 5 predictions
47
+ top5_prob, top5_catid = torch.topk(probabilities, 5)
48
+ return {f"Class {i}": float(prob) for i, prob in zip(top5_catid, top5_prob)}
49
+ except Exception as e:
50
+ print(f"Error during prediction: {str(e)}")
51
+ print(traceback.format_exc())
52
+ return {"error": str(e)}
53
 
54
+ # Create Gradio interface with error handling
55
  iface = gr.Interface(
56
  fn=predict,
57
  inputs=gr.Image(),
58
  outputs=gr.Label(num_top_classes=5),
59
  title="ResNet Image Classification",
60
+ description="Upload an image to classify it using ResNet",
61
+ allow_flagging="never"
62
  )
63
 
64
+ # Add error handling to launch
65
+ try:
66
+ iface.launch()
67
+ except Exception as e:
68
+ print(f"Error launching interface: {str(e)}")
69
+ print(traceback.format_exc())