wandikafp commited on
Commit
3e71998
·
1 Parent(s): de2dbac
Files changed (2) hide show
  1. app.py +46 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ # Load your resnet18 model from Hugging Face
8
+ model = models.resnet18()
9
+ model.fc = nn.Linear(model.fc.in_features, 4) # Assuming 4 classes
10
+ checkpoint = torch.hub.load_state_dict_from_url(
11
+ 'https://huggingface.co/wandikafp/resnet18-tom-and-jerry-classifier/resolve/main/pytorch_model.bin',
12
+ map_location=torch.device('cpu')
13
+ )
14
+ model.load_state_dict(checkpoint)
15
+ model.eval()
16
+
17
+ # Define image transformations
18
+ transform = transforms.Compose([
19
+ transforms.Resize((224, 224)),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
22
+ ])
23
+
24
+ # Define a prediction function
25
+ def classify_image(image):
26
+ image = Image.fromarray(image) # Convert to PIL image
27
+ image = transform(image).unsqueeze(0) # Preprocess the image
28
+
29
+ with torch.no_grad():
30
+ outputs = model(image)
31
+ _, predicted = torch.max(outputs, 1)
32
+
33
+ labels = ['tom', 'jerry', 'tom_jerry_0', 'tom_jerry_1']
34
+ return labels[predicted.item()]
35
+
36
+ # Create Gradio interface
37
+ interface = gr.Interface(
38
+ fn=classify_image,
39
+ inputs="image",
40
+ outputs="label",
41
+ title="Tom and Jerry Classifier",
42
+ description="Classify images as 'tom', 'jerry', 'tom_jerry_0', or 'tom_jerry_1'."
43
+ )
44
+
45
+ # Launch the Gradio app
46
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio