classify_blood / app.py
ninenox's picture
Update app.py
c0f45b2 verified
raw
history blame
1.59 kB
import gradio as gr
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
import torch
# โหลด model จากไฟล์ .pt
model = torch.load('model_blood.pt')
device = 'cpu' # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
# ข้อความ string ค่าของ classes ที่มี
targets = ['Negative','Positive']
# เตรียม data ก่อนเข้าโมเดล
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ฟังก์ชันประมวลผลรูปภาพ
def classify_image(img):
img = Image.fromarray(img.astype('uint8'), 'RGB')
img = transform(img).unsqueeze(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
img = img.to(device)
with torch.no_grad():
prediction = torch.nn.functional.softmax(model(img)[0], dim=0)
confidences = {targets[i]: float(prediction[i]) for i in range(2)}
return confidences
# รันโปรแกรมเว็บ gradio
demo = gr.Interface(fn=classify_image,
inputs=gr.Image(width=224, height=224),
outputs=gr.Label(num_top_classes=2),
examples=["examples/negative.jpeg", "examples/positive.jpeg"]) # ภาพตัวอย่างมาจาก folder examples
demo.launch(share=True, debug=True)