Polos-Demo / app.py
yuwd's picture
update
a005919
raw
history blame
2.58 kB
import streamlit as st
from PIL import Image
from polos.models import download_model, load_checkpoint
# Model loading. st.cache_resource memoizes the return value, so the
# checkpoint is fetched and loaded once rather than on every script rerun.
@st.cache_resource()
def load_model():
    """Fetch the "polos" checkpoint via download_model and return the loaded model."""
    checkpoint_path = download_model("polos")
    return load_checkpoint(checkpoint_path)


model = load_model()
# Page header.
st.title('Polos Demo')

# Seed per-session state. Each key is only written when it is absent, so
# values set by later widgets survive Streamlit's script reruns.
_session_defaults = {
    'image': None,
    'user_input': '',
    'user_refs': [
        "there is a dog sitting on a couch with a person reaching out",
        "a dog laying on a couch with a person",
        'a dog is laying on a couch with a person',
    ],
    'score': None,
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Default image shown before (or instead of) an upload; cached so the file is
# read only once.
@st.cache_resource()
def get_default_image():
    """Return test.jpg converted to RGB, or a 200x200 gray placeholder if it is missing."""
    try:
        with Image.open("test.jpg") as sample:
            return sample.convert("RGB")
    except FileNotFoundError:
        # Placeholder used when the bundled sample image is not present.
        return Image.new('RGB', (200, 200), color='gray')


default_image = get_default_image()
# Image upload widget. An uploaded file replaces the default image. Clearing
# the uploader reverts to the default image, so a removed upload is neither
# displayed nor scored.
uploaded_image = st.file_uploader("Upload your image:", type=["jpg", "jpeg", "png"])
if uploaded_image is None:
    st.session_state.image = default_image
else:
    try:
        st.session_state.image = Image.open(uploaded_image).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # The `type` filter only checks the file extension. PIL raises
        # UnidentifiedImageError (an OSError subclass) for undecodable content
        # and DecompressionBombError for oversized images; report the problem
        # instead of crashing the app.
        st.error("Could not read the uploaded file as an image.")
        st.session_state.image = default_image
# Always show the image that will be scored.
st.image(st.session_state.image, caption="Displayed Image", use_column_width=True)
# Reference sentences, one per line, pre-filled from session state. The
# split/join round-trip keeps the text exactly as the user typed it.
refs_prompt = "Enter reference sentences (separate each by a newline):"
user_refs = st.text_area(refs_prompt, value="\n".join(st.session_state.user_refs))
st.session_state.user_refs = user_refs.split("\n")

# Candidate sentence to be scored against the references.
user_input = st.text_input(
    "Enter the input sentence:",
    value=st.session_state.user_input,
)
st.session_state.user_input = user_input
# Compute button: score the input sentence against the reference sentences for
# the current image, then show the most recent score.
if st.button('Compute'):
    # Drop blank lines (e.g. a trailing newline in the text area) so empty
    # strings are not passed to the model as references.
    refs = [ref.strip() for ref in st.session_state.user_refs if ref.strip()]
    if not st.session_state.user_input.strip():
        st.warning("Please enter an input sentence to evaluate.")
        # Clear any earlier score so it is not mistaken for this request's result.
        st.session_state.score = None
    elif not refs:
        st.warning("Please enter at least one reference sentence.")
        st.session_state.score = None
    else:
        data = [
            {
                "img": st.session_state.image,
                "mt": st.session_state.user_input,
                "refs": refs,
            }
        ]
        # Single-sample batch on CPU. The second element of the returned pair
        # is indexed per sample; assumed to be one score per input dict
        # (TODO confirm against polos' predict).
        _, scores = model.predict(data, batch_size=1, cuda=False)
        st.session_state.score = scores[0]

# Show the last computed score, if any.
if st.session_state.score is not None:
    st.metric(label="Score", value=f"{st.session_state.score:.5f}")