FredZhang7
commited on
Commit
•
9e69b49
1
Parent(s):
66d76af
Upload model
Browse files
Model.py
CHANGED
@@ -103,7 +103,7 @@ class InceptionV3ModelForImageClassification(PreTrainedModel):
|
|
103 |
image = Image.open(path).convert('RGB')
|
104 |
return image
|
105 |
|
106 |
-
def predict(self, path, device="cuda"):
|
107 |
image = self.open_image(path)
|
108 |
image = self.transform(image)
|
109 |
image = image.unsqueeze(0)
|
@@ -116,6 +116,7 @@ class InceptionV3ModelForImageClassification(PreTrainedModel):
|
|
116 |
self.cpu()
|
117 |
with torch.no_grad():
|
118 |
out, aux = self(image)
|
119 |
-
|
|
|
120 |
_, predicted = torch.max(out.data, 1)
|
121 |
return self.config.classes[predicted.item()]
|
|
|
103 |
image = Image.open(path).convert('RGB')
|
104 |
return image
|
105 |
|
106 |
+
def predict(self, path, device="cuda", print_tensor=True):
|
107 |
image = self.open_image(path)
|
108 |
image = self.transform(image)
|
109 |
image = image.unsqueeze(0)
|
|
|
116 |
self.cpu()
|
117 |
with torch.no_grad():
|
118 |
out, aux = self(image)
|
119 |
+
if print_tensor:
|
120 |
+
print(out)
|
121 |
_, predicted = torch.max(out.data, 1)
|
122 |
return self.config.classes[predicted.item()]
|