File size: 957 Bytes
84c4b50 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import torch
class RAMPlusModel:
    """Image tagger wrapping a Hugging Face image-classification checkpoint.

    The feature extractor and model are loaded once at construction time and
    the model is switched to evaluation mode; ``predict`` then returns the
    highest-scoring label names for an image.
    """

    # Checkpoint used when no ``model_name`` is given (the value the
    # original implementation hard-coded).
    DEFAULT_MODEL_NAME = "xinyu1205/recognize-anything-plus-model"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the feature extractor and model for ``model_name``.

        Args:
            model_name: Hugging Face hub id or local path passed to both
                ``from_pretrained`` calls. Defaults to the RAM++ checkpoint.
        """
        # NOTE(review): confirm this checkpoint ships a transformers config
        # loadable by AutoModelForImageClassification; RAM++ weights are
        # commonly distributed for the standalone `ram` package instead.
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        # Inference only: put modules such as dropout into eval behaviour.
        self.model.eval()

    def predict(self, image, top_k: int = 5) -> list:
        """Return the names of the ``top_k`` highest-scoring labels for ``image``.

        Args:
            image: Input accepted by the feature extractor (e.g. a PIL image).
            top_k: Maximum number of labels to return. It is clamped to the
                number of labels the model outputs, so a small label set does
                not make ``torch.topk`` raise.

        Returns:
            List of label strings looked up in ``model.config.id2label``,
            ordered from highest to lowest logit. Only the first item of the
            batch is used.
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        # TODO: ranking raw logits is a placeholder and may need adjusting to
        # the model's actual output (presumably multi-label for RAM++, where a
        # sigmoid + threshold could be more appropriate) — confirm.
        k = min(top_k, logits.size(-1))
        top = torch.topk(logits, k=k)
        return [self.model.config.id2label[i.item()] for i in top.indices[0]]
# Create a module-level model instance.
# NOTE: this runs at import time, so importing this module loads (and, for a
# hub id, possibly downloads) the checkpoint as a side effect.
model = RAMPlusModel()