Spaces:
Runtime error
from torchvision import models, transforms | |
from PIL import Image | |
import torch | |
import torch.nn as nn | |
import io | |
import streamlit as st | |
import time | |
st.title("パーソナルカラー診断AI")

# Preprocessing applied to every uploaded image before inference:
# resize to a fixed SIZE x SIZE square, convert the PIL image to a float
# tensor (torchvision's ToTensor yields C x H x W in [0, 1]), then normalise
# each channel with MEAN/STD (the widely used ImageNet channel statistics).
SIZE = 224
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

transform = transforms.Compose(
    [
        transforms.Resize((SIZE,) * 2),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
)
n_classes = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_model(weights_path: str, num_classes: int) -> nn.Module:
    """Build the ResNet-152 classifier and load the fine-tuned weights.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` makes this expensive construction happen once per
    server process instead of on every rerun.

    Args:
        weights_path: Path to the saved ``state_dict`` for the fine-tuned model.
        num_classes: Number of output units for the replacement ``fc`` head.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        RuntimeError: If the saved state dict does not match the architecture
            (``load_state_dict`` is strict by default).
    """
    # Build without ImageNet weights: the strict load_state_dict below
    # overwrites every parameter, so downloading the pretrained checkpoint
    # (the old ``pretrained=True``) only cost time, memory and network access.
    # A bare call is also valid on torchvision versions before and after the
    # ``pretrained`` -> ``weights`` API change.
    net = models.resnet152()
    # Swap the 1000-way ImageNet head for a num_classes-way classifier.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.to(device)
    net.eval()
    return net


model = _load_model('Resnet_2024_0214_version1', n_classes)
# Module-level flags. Neither name is read anywhere else in the visible
# script (the UI keeps its own 'skip' flag in st.session_state).
# NOTE(review): likely leftovers — confirm nothing else uses them, then remove.
view_flag, skip = True, False
def predict_image(img): | |
img = img.convert('RGB') | |
img_transformed = transform(img) | |
inputs = img_transformed.unsqueeze(0).to(device) | |
with torch.no_grad(): | |
outputs = model(inputs) | |
_, preds = torch.max(outputs, 1) | |
return preds.item() | |
uploaded_file = st.file_uploader('Choose an image...', type=['jpg', 'png'])
if uploaded_file:
    try:
        img = Image.open(uploaded_file)
        # Image.open only parses the header. Decode the pixels now, so a
        # truncated or corrupt upload fails here rather than deep inside
        # predict_image with an unhandled traceback.
        img.load()
    except OSError:
        # Pillow reports unreadable/corrupt images with OSError subclasses
        # (e.g. UnidentifiedImageError); show a message instead of crashing.
        st.error("画像を読み込めませんでした。別の画像をアップロードしてください。")
    else:
        # NOTE(review): use_column_width is deprecated in recent Streamlit
        # releases in favour of use_container_width — confirm the pinned
        # version before switching.
        st.image(img, caption="Uploaded Image", use_column_width=True)
        pred = predict_image(img)
        # Map the predicted class index to a season label. Any index other
        # than 0-2 falls back to winter, like the original if/elif/else chain.
        season_type = {0: "秋", 1: "春", 2: "夏"}.get(pred, "冬")

        # Initialise per-session flags once; existing values are kept across reruns.
        for key in ('show_video', 'skip', 'result'):
            if key not in st.session_state:
                st.session_state[key] = False

        st.write(f"パーソナルカラー診断結果:{season_type} ")
        st.write("あなたにおすすめの色はこちらです")
        st.session_state.result = True
        # Expects one image per season label, e.g. "秋.png", relative to the
        # working directory.
        st.image(f"{season_type}.png")
        st.write(
            """
あなたにおすすめの商品はこちらです
"""
        )
        st.image("服.png")