# Personal_Color / app.py
# Uploaded by kikuepi ("Upload 6 files", commit 0352c54, verified), 2.73 kB.
from torchvision import models, transforms
from PIL import Image
import torch
import torch.nn as nn
import io
import streamlit as st
import time
st.title("パーソナルカラー予測")

# Input preprocessing. MEAN/STD are the usual ImageNet normalisation
# statistics; presumably the fine-tuned checkpoint below was trained with the
# same 224x224 resize + normalisation (TODO confirm against training code).
SIZE = 224
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# Output size of the classifier head (one logit per season label).
n_classes = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_model(weights_path="Resnet_2024_0214_version1", num_classes=n_classes):
    """Build the ResNet-152 classifier and load the fine-tuned weights.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the expensive construction and checkpoint
    load happen only once (per distinct arguments) instead of on every rerun.
    The returned model object is therefore shared between reruns.

    Args:
        weights_path: Path of the saved ``state_dict`` to load.
        num_classes: Output size of the replacement ``fc`` layer.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # No ImageNet weights are requested: load_state_dict (strict by default)
    # overwrites every parameter and buffer, so downloading them is wasted work.
    net = models.resnet152()
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.to(device)
    net.eval()
    return net


model = _load_model()
num_ftrs = model.fc.in_features  # kept so the module-level name still exists

# Not read anywhere else in this script (the video flow uses
# st.session_state.skip instead); kept so the module's names are unchanged.
view_flag = True
skip = False

left_column, right_column = st.columns(2)
def predict_image(img) -> int:
    """Classify one image and return the index of the top-scoring class.

    The image is converted to 3-channel RGB, passed through the module-level
    ``transform`` pipeline, given a leading batch dimension of 1, moved to
    ``device`` and scored by ``model`` with gradient tracking disabled.

    Args:
        img: A PIL image in any mode that ``convert('RGB')`` accepts.

    Returns:
        Python ``int`` index of the largest output along the class axis.
    """
    rgb_img = img.convert('RGB')
    batch = torch.unsqueeze(transform(rgb_img), 0).to(device)
    with torch.no_grad():
        logits = model(batch)
    best_class = logits.argmax(dim=1)
    return best_class.item()
# Model output index -> season label (0: 秋/autumn, 1: 春/spring,
# 2: 夏/summer, 3: 冬/winter). predict_image returns an index into the
# 4-way classifier head, so every prediction is a valid position here.
SEASON_LABELS = ("秋", "春", "夏", "冬")
# Order in which the seasons are listed in the correction dropdown.
SEASON_CHOICES = ('春', '夏', '秋', '冬')

uploaded_file = st.file_uploader('Choose an image...', type=['jpg', 'png'])
if uploaded_file:
    with left_column:
        try:
            img = Image.open(uploaded_file)
            # Image.open is lazy; force decoding so corrupt/truncated uploads
            # fail here with a readable message instead of a traceback later.
            img.load()
        except OSError:  # also covers PIL.UnidentifiedImageError
            st.error("画像を読み込めませんでした。別の画像を選んでください。")
            st.stop()
        st.image(img, caption="Uploaded Image", use_column_width=True)
    pred = predict_image(img)
    season_type = SEASON_LABELS[pred]

    # Per-session flags driving the video -> recommendation flow.
    if 'show_video' not in st.session_state:
        st.session_state.show_video = False
    if 'skip' not in st.session_state:
        st.session_state.skip = False

    # Let the user correct the result. Without an explicit index Streamlit
    # preselects the first option ('春'), which would overwrite every
    # prediction; defaulting to the predicted season keeps the model's answer
    # unless the user actually changes it.
    person_result = st.selectbox(
        '結果の修正',
        SEASON_CHOICES,
        index=SEASON_CHOICES.index(season_type),
    )
    season_type = person_result

    with left_column:
        st.write(f"パーソナルカラーは {season_type} です")
        # Button that reveals the recommended products.
        view_recommend = st.button("おすすめの商品を見る")

    with right_column:
        if view_recommend:
            st.session_state.show_video = True
        # Show the video at most once per session, then auto-advance. There
        # is no skip button: the script blocks for 5 s, then the flags flip so
        # the recommendation text is rendered (and the video is omitted on
        # later reruns).
        if st.session_state.show_video and not st.session_state.skip:
            st.video("sample.mp4", start_time=0)
            time.sleep(5)
            st.session_state.show_video = False
            st.session_state.skip = True
        # NOTE(review): `skip` is never reset, so after the first viewing the
        # recommendation stays visible for every later upload in the session.
        if st.session_state.skip:
            st.write("\nレコメンド文です。\n")