tranquilkd commited on
Commit
71413cc
·
1 Parent(s): 140cb56

First commit

Browse files
Files changed (5) hide show
  1. .gitattributes +0 -0
  2. README.md +0 -0
  3. app.py +85 -0
  4. best.pth +3 -0
  5. requirement.txt +3 -0
.gitattributes CHANGED
File without changes
README.md CHANGED
File without changes
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import traceback
3
+ import gradio as gr
4
+ import torch
5
+ from torchvision.models import get_model
6
+ from torchvision.transforms import v2
7
+ from torchvision.transforms.functional import InterpolationMode
8
+
9
+
10
+ # Imagenet-1k classes
11
+ if not os.path.exists("imagenet_classes.txt"):
12
+ os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
13
+
14
+ # Download an example image from the pytorch website
15
+ if not os.path.exists("dog.jpg"):
16
+ torch.hub.download_url_to_file("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
17
+
18
+
19
+ # Function to load the model with custom weights
20
+ def load_model(weights_path):
21
+ model = get_model("resnet50", num_classes=1000)
22
+ ckpt = torch.load(weights_path, map_location=torch.device("cpu"))
23
+ model.load_state_dict(ckpt["model_state_dict"])
24
+ model.eval()
25
+ return model
26
+
27
+
28
+ # Function for making predictions and returning top 5 predictions with confidence
29
+ def classify_image(image):
30
+ # Preprocess the input image
31
+ image = transform(image).unsqueeze(0) # Add batch dimension
32
+
33
+ with torch.no_grad():
34
+ output = model(image) # Get model output
35
+
36
+ # The output has unnormalized scores. To get probabilities, you can run a softmax on it.
37
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
38
+
39
+ # Read the categories
40
+ with open("imagenet_classes.txt", "r") as f:
41
+ categories = [s.strip() for s in f.readlines()]
42
+
43
+ # Show top categories per image
44
+ top5_prob, top5_catid = torch.topk(probabilities, 5)
45
+ result = {}
46
+ for i in range(top5_prob.size(0)):
47
+ result[categories[top5_catid[i]]] = top5_prob[i].item()
48
+ return result
49
+
50
+ # Define image transformation to match the model input
51
+ transform = v2.Compose([
52
+ v2.Resize(256, interpolation=InterpolationMode.BILINEAR, antialias=True),
53
+ v2.CenterCrop(224),
54
+ v2.PILToTensor(),
55
+ v2.ToDtype(torch.float, scale=True),
56
+ v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
57
+ v2.ToPureTensor(),
58
+ ])
59
+
60
+ # Path to the pre-trained model weights (should be set by the user)
61
+ model_weights_path = "best.pth"
62
+ model = load_model(model_weights_path)
63
+
64
+ # Define the Gradio interface
65
+ iface = gr.Interface(
66
+ fn=classify_image, # The function to run on input
67
+ inputs=gr.Image(type="pil"), # Image input (in PIL format)
68
+ outputs=gr.Label(num_top_classes=5), # Output will be the predicted top 5 classes with confidence scores
69
+ title = "Image Recognition using ResNet-50 trained on Imagenet-1K",
70
+ description = "<p style='text-align: center'> Gradio demo for ResNet, Deep residual networks pre-trained on ImageNet. To use it, simply upload your image, or click one of the examples to load them. </p>",
71
+ article = "<p style='text-align: center'> \
72
+ <a href='https://arxiv.org/abs/1512.03385' target='_blank'>Deep Residual Learning for Image Recognition</a> | \
73
+ <a href='https://github.com/KD1994/session-9-imagenet-resnet50' target='_blank'>Github Repo</a> \
74
+ </p>",
75
+ examples = [
76
+ ['dog.jpg']
77
+ ]
78
+ )
79
+
80
+ # Add error handling to launch
81
+ try:
82
+ iface.launch(share=True)
83
+ except Exception as e:
84
+ print(f"Error launching interface: {str(e)}")
85
+ print(traceback.format_exc())
best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6fdff2ad1f20d1c622f84375e1ec1dbdecd8c8f5488beade445bbfa6509e0fd
3
+ size 204794470
requirement.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==5.9.1
2
+ torch==2.3.1
3
+ torchvision==0.18.1