tranquilkd commited on
Commit
1615b9d
·
verified ·
1 Parent(s): 599c421

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -18
app.py CHANGED
@@ -3,19 +3,10 @@ import traceback
3
  import gradio as gr
4
  import torch
5
  from torchvision.models import get_model
6
- from torchvision.transforms import v2
7
  from torchvision.transforms.functional import InterpolationMode
8
 
9
 
10
- # Imagenet-1k classes
11
- if not os.path.exists("imagenet_classes.txt"):
12
- os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
13
-
14
- # Download an example image from the pytorch website
15
- if not os.path.exists("dog.jpg"):
16
- torch.hub.download_url_to_file("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
17
-
18
-
19
  # Function to load the model with custom weights
20
  def load_model(weights_path):
21
  model = get_model("resnet50", num_classes=1000)
@@ -48,13 +39,12 @@ def classify_image(image):
48
  return result
49
 
50
  # Define image transformation to match the model input
51
- transform = v2.Compose([
52
- v2.Resize(256, interpolation=InterpolationMode.BILINEAR, antialias=True),
53
- v2.CenterCrop(224),
54
- v2.PILToTensor(),
55
- v2.ToDtype(torch.float, scale=True),
56
- v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
57
- v2.ToPureTensor(),
58
  ])
59
 
60
  # Path to the pre-trained model weights (should be set by the user)
@@ -67,13 +57,17 @@ iface = gr.Interface(
67
  inputs=gr.Image(type="pil"), # Image input (in PIL format)
68
  outputs=gr.Label(num_top_classes=5), # Output will be the predicted top 5 classes with confidence scores
69
  title = "Image Recognition using ResNet-50 trained on Imagenet-1K",
 
70
  description = "<p style='text-align: center'> Gradio demo for ResNet, Deep residual networks pre-trained on ImageNet. To use it, simply upload your image, or click one of the examples to load them. </p>",
71
  article = "<p style='text-align: center'> \
72
  <a href='https://arxiv.org/abs/1512.03385' target='_blank'>Deep Residual Learning for Image Recognition</a> | \
73
  <a href='https://github.com/KD1994/session-9-imagenet-resnet50' target='_blank'>Github Repo</a> \
74
  </p>",
75
  examples = [
76
- ['dog.jpg']
 
 
 
77
  ]
78
  )
79
 
 
3
  import gradio as gr
4
  import torch
5
  from torchvision.models import get_model
6
+ from torchvision.transforms import transforms
7
  from torchvision.transforms.functional import InterpolationMode
8
 
9
 
 
 
 
 
 
 
 
 
 
10
  # Function to load the model with custom weights
11
  def load_model(weights_path):
12
  model = get_model("resnet50", num_classes=1000)
 
39
  return result
40
 
41
  # Define image transformation to match the model input
42
+ transform = transforms.Compose([
43
+ transforms.Resize(256, interpolation=InterpolationMode.BILINEAR, antialias=True),
44
+ transforms.CenterCrop(224),
45
+ transforms.PILToTensor(),
46
+ transforms.ConvertImageDtype(torch.float),
47
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
48
  ])
49
 
50
  # Path to the pre-trained model weights (should be set by the user)
 
57
  inputs=gr.Image(type="pil"), # Image input (in PIL format)
58
  outputs=gr.Label(num_top_classes=5), # Output will be the predicted top 5 classes with confidence scores
59
  title = "Image Recognition using ResNet-50 trained on Imagenet-1K",
60
+ live = True,
61
  description = "<p style='text-align: center'> Gradio demo for ResNet, Deep residual networks pre-trained on ImageNet. To use it, simply upload your image, or click one of the examples to load them. </p>",
62
  article = "<p style='text-align: center'> \
63
  <a href='https://arxiv.org/abs/1512.03385' target='_blank'>Deep Residual Learning for Image Recognition</a> | \
64
  <a href='https://github.com/KD1994/session-9-imagenet-resnet50' target='_blank'>Github Repo</a> \
65
  </p>",
66
  examples = [
67
+ ['examples/dog.jpg'],
68
+ ['examples/great-white-shark.jpg'],
69
+ ['examples/american-goldfinch.jpg'],
70
+ ['examples/hognose-snake.jpg']
71
  ]
72
  )
73