Spaces:
Runtime error
Runtime error
from torchvision import models, transforms | |
from PIL import Image | |
import torch | |
import torch.nn as nn | |
import io | |
import streamlit as st | |
import time | |
st.title("パーソナルカラー予測")

# Classifier input settings. MEAN/STD are the widely used ImageNet
# per-channel mean/std values.
SIZE = 224
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

# PIL image -> model input: resize to an exact SIZE x SIZE square (aspect
# ratio is not preserved), convert to a float tensor, then normalise each
# channel with MEAN/STD.
transform = transforms.Compose(
    [
        transforms.Resize(size=(SIZE, SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
)
# Number of output classes: the app maps them to the four seasons.
n_classes = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model(weights_path="Resnet_2024_0214_version1", num_classes=n_classes):
    """Build the ResNet-152 season classifier and load its fine-tuned weights.

    Streamlit re-executes the whole script on every widget interaction, so the
    network is cached with ``st.cache_resource``: it is built and its
    checkpoint read once per process instead of on every rerun.

    Args:
        weights_path: Path of a ``state_dict`` checkpoint readable by ``torch.load``.
        num_classes: Output size of the replacement final layer; must match
            the checkpoint.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # No ImageNet weights are fetched: load_state_dict (strict by default)
    # overwrites every parameter below, so downloading them is wasted work.
    net = models.resnet152(weights=None)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location=device)
    net.load_state_dict(state_dict)
    # Module.to() and Module.eval() both return the module itself.
    return net.to(device).eval()


model = load_model()
# Two-column layout: uploaded image, result and button on the left;
# video / recommendation text on the right.
left_column, right_column = st.columns(2)
def predict_image(img):
    """Classify a PIL image and return the index of the best-scoring class.

    The image is converted to RGB (e.g. dropping a PNG alpha channel),
    preprocessed with the module-level ``transform``, wrapped as a batch of
    one on ``device`` and passed through ``model`` without gradient tracking.

    Args:
        img: A PIL image.

    Returns:
        int: Index of the highest score in the model output (first one on ties).
    """
    rgb_image = img.convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        scores = model(batch)
    best_class = scores.argmax(dim=1)
    return best_class.item()
# Season label for each classifier output index (0..3).
CLASS_TO_SEASON = ("秋", "春", "夏", "冬")
# Options of the manual-correction selectbox, in display order.
SEASON_CHOICES = ('春', '夏', '秋', '冬')

uploaded_file = st.file_uploader('Choose an image...', type=['jpg', 'png'])
if uploaded_file:
    with left_column:
        try:
            img = Image.open(uploaded_file)
            # Force decoding now so truncated files fail here, not inside the model call.
            img.load()
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # uploads that are not decodable images despite their extension.
            st.error("画像を読み込めませんでした。別の画像を選択してください。")
            st.stop()
        st.image(img, caption="Uploaded Image", use_column_width=True)
    pred = predict_image(img)
    season_type = CLASS_TO_SEASON[pred]

    # Per-session flags for the "watch the video, then show the
    # recommendation" flow; initialised once and kept across reruns.
    if 'show_video' not in st.session_state:
        st.session_state.show_video = False
    if 'skip' not in st.session_state:
        st.session_state.skip = False

    # Let the user correct the result. Without `index`, the selectbox would
    # default to its first option ('春') and always override the prediction;
    # pre-selecting the predicted season keeps it unless the user changes it.
    person_result = st.selectbox(
        '結果の修正',
        SEASON_CHOICES,
        index=SEASON_CHOICES.index(season_type),
    )
    if person_result:
        season_type = person_result

    with left_column:
        st.write(f"パーソナルカラーは {season_type} です")
        # Button that reveals the product recommendations.
        view_recommend = st.button("おすすめの商品を見る")
    with right_column:
        if view_recommend:
            st.session_state.show_video = True
        # Show the video once, unless it has already been watched.
        if st.session_state.show_video and not st.session_state.skip:
            st.video("sample.mp4", start_time=0)
            # There is no skip control: after a fixed 5-second wait the video
            # is marked as watched, so the recommendation appears below and
            # the video is no longer rendered on later reruns.
            time.sleep(5)
            st.session_state.show_video = False
            st.session_state.skip = True
        # NOTE(review): 'skip' is never reset, so once the video has been
        # watched the recommendation shows on every rerun (even for new
        # uploads) without pressing the button — confirm this is intended.
        if st.session_state.skip:
            st.write(
                """
                レコメンド文です。
                """
            )