Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
from polos.models import download_model, load_checkpoint | |
# Model loading.
@st.cache_resource
def load_model():
    """Download the "polos" checkpoint and return the loaded model.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` makes the download and checkpoint load happen once
    per server process; every rerun and session then reuses the same object.

    NOTE(review): this assumes ``model.predict`` is safe to share across
    sessions (inference only, no per-call mutable state). Confirm.

    Returns:
        The object produced by ``load_checkpoint`` for the downloaded path.
    """
    model_path = download_model("polos")
    model = load_checkpoint(model_path)
    return model


model = load_model()
# Streamlit page header.
st.title('Polos Demo')

# Initial session-state values. Each key is written only when absent, so
# values the user has already entered survive Streamlit's reruns.
_SESSION_DEFAULTS = {
    'image': None,
    'user_input': '',
    'user_refs': [
        "there is a dog sitting on a couch with a person reaching out",
        "a dog laying on a couch with a person",
        'a dog is laying on a couch with a person'
    ],
    'score': None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Default image used until the user uploads one.
def get_default_image(path="test.jpg"):
    """Return the image at *path* in RGB mode, or a gray placeholder.

    Args:
        path: Image file to load. Defaults to ``"test.jpg"``, resolved
            relative to the current working directory.

    Returns:
        A PIL image in RGB mode. If the file is missing, unreadable, or not
        a decodable image, a 200x200 gray image is returned instead.
    """
    try:
        # ``with`` releases the file handle. ``convert`` returns a new,
        # fully loaded image, so the result stays valid after the file closes.
        with Image.open(path) as img:
            return img.convert("RGB")
    except OSError:
        # Covers FileNotFoundError and PIL's UnidentifiedImageError, since
        # both subclass OSError. A corrupt file falls back instead of crashing.
        return Image.new('RGB', (200, 200), color='gray')


default_image = get_default_image()
# Image upload widget. A successfully decoded upload replaces the stored image.
uploaded_image = st.file_uploader("Upload your image:", type=["jpg", "jpeg", "png"])
if uploaded_image is not None:
    try:
        st.session_state.image = Image.open(uploaded_image).convert("RGB")
    except OSError:
        # The extension filter does not guarantee valid image data. PIL
        # raises UnidentifiedImageError (an OSError subclass) or OSError for
        # undecodable files. Report it rather than crash the page.
        st.error("Could not read the uploaded file as an image.")
# Fall back to the default when no image has been stored yet. This also
# covers a failed first upload.
if st.session_state.image is None:
    st.session_state.image = default_image
# Always display whichever image is current.
st.image(st.session_state.image, caption="Displayed Image", use_column_width=True)
# Reference sentences, one per line, pre-filled from session state.
user_refs = st.text_area(
    "Enter reference sentences (separate each by a newline):",
    value="\n".join(st.session_state.user_refs),
)
st.session_state.user_refs = user_refs.split("\n")

# Candidate sentence to be scored against the image and the references.
user_input = st.text_input(
    "Enter the input sentence:",
    value=st.session_state.user_input,
)
st.session_state.user_input = user_input
# Compute button: score the input sentence against the image and references.
if st.button('Compute'):
    # Drop blank or whitespace-only lines, e.g. from a trailing newline in
    # the text area, so empty strings are not sent to the model as references.
    refs = [ref.strip() for ref in st.session_state.user_refs if ref.strip()]
    candidate = st.session_state.user_input.strip()
    if not candidate:
        # Clear any earlier score so it isn't shown as if it belonged to the
        # current (empty) input.
        st.session_state.score = None
        st.warning("Please enter an input sentence.")
    elif not refs:
        # The model is given references to compare against; require at least one.
        st.session_state.score = None
        st.warning("Please enter at least one reference sentence.")
    else:
        # One-sample batch in the format model.predict consumes here.
        data = [
            {
                "img": st.session_state.image,
                "mt": candidate,
                "refs": refs,
            }
        ]
        # Model prediction. It can take a while on CPU, so show progress.
        with st.spinner("Computing score..."):
            _, scores = model.predict(data, batch_size=1, cuda=False)
        st.session_state.score = scores[0]
# Show the most recently computed score, if there is one.
latest_score = st.session_state.score
if latest_score is not None:
    st.metric(label="Score", value=format(latest_score, ".5f"))