|
|
|
""" |
|
------------------------------------------------- |
|
@File Name: app.py |
|
@Author: Luyao.zhang |
|
@Date: 2023/5/15 |
|
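@Description: Streamlit app that downloads the configured detection model weights and runs them on image, video or webcam input.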
|
------------------------------------------------- |
|
""" |
|
|
|
# Where configuration comes from: when True, model settings are read from
# st.secrets; when False, they are read from environment variables that are
# first loaded from a .env file (presumably the Hugging Face vs. local setup).
HuggingFace = True

# Index of a model (the N in the mN_* settings) whose cached weights should be
# deleted at startup so they get downloaded again; None disables this.
update_model_id = None

# Local directory that holds downloaded model weight files.
base_download_path = "downloaded"

if HuggingFace is False:
    # Only needed for local runs; populates os.environ from a .env file.
    from dotenv import load_dotenv
    load_dotenv()
|
|
|
import os
import subprocess
from pathlib import Path

import streamlit as st
from PIL import Image

import config
from utils import load_model, infer_uploaded_image, infer_uploaded_video, infer_uploaded_webcam
|
|
|
# URL query parameters; each value is indexed as a list below
# (query_params['model'][0]), so `?model=<description>` pre-selects a model.
# NOTE(review): newer Streamlit releases deprecate st.experimental_get_query_params
# in favour of st.query_params (which yields str values, not lists) — confirm the
# pinned Streamlit version before changing this call.
query_params = st.experimental_get_query_params()
|
|
|
|
|
# When update_model_id names a model, delete its cached weight file so that the
# download loop further down fetches a fresh copy during this run.
if update_model_id is not None:
    if HuggingFace is False:
        model_name = os.getenv("m{}_name".format(update_model_id))
        model_extname = os.getenv("m{}_type".format(update_model_id))
    else:
        model_name = st.secrets["m{}_name".format(update_model_id)]
        model_extname = st.secrets["m{}_type".format(update_model_id)]

    path_model = os.path.join(base_download_path, model_name + model_extname)
    if os.path.exists(path_model):
        try:
            os.remove(path_model)
        except OSError as err:
            # Best effort: keep the app running, but report why removal failed
            # instead of silently swallowing every exception type.
            print("Cannot remove", path_model, err)
|
|
|
|
|
|
|
# Model to pre-select in the sidebar: the first `model` query parameter when
# present, otherwise 'Crowded-Human'. It is checked against the real model
# list once that list has been built.
qmodel = query_params['model'][0] if 'model' in query_params else 'Crowded-Human'
|
|
|
|
|
# Page-level settings: browser tab title and icon, wide layout, and the
# sidebar (which holds all configuration widgets) open on load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="YOLO.dog",
    page_icon="🤖",
)
|
|
|
|
|
|
|
# Make sure the weights directory exists. exist_ok=True avoids the
# check-then-create race; any other failure (e.g. permissions) is reported
# rather than mislabelled as "folder exists", and the app keeps going as before.
try:
    os.makedirs(base_download_path, exist_ok=True)
except OSError as err:
    print("Cannot create folder", base_download_path, err)
|
|
|
# How many models are configured; each one is described by the
# m{i}_name / m{i}_griv / m{i}_type / m{i}_desc settings read below.
model_count = int(
    os.getenv("model_count") if HuggingFace is False else st.secrets["model_count"]
)
|
|
|
# Build the model registry and make sure every model's weights exist locally.
#   model_info:  model description -> local weights path
#   models_list: model descriptions in configuration order (sidebar options)
# Missing weights are fetched from Google Drive with wget.
model_info = {}
models_list = []
for i in range(model_count):
    if HuggingFace is False:
        model_name = os.getenv("m{}_name".format(i))
        gdrive_id = os.getenv("m{}_griv".format(i))
        model_extname = os.getenv("m{}_type".format(i))
        model_desc = os.getenv("m{}_desc".format(i))
    else:
        model_name = st.secrets["m{}_name".format(i)]
        gdrive_id = st.secrets["m{}_griv".format(i)]
        model_extname = st.secrets["m{}_type".format(i)]
        model_desc = st.secrets["m{}_desc".format(i)]

    print(i, model_name, gdrive_id, model_extname, model_desc)

    path_model = os.path.join(base_download_path, model_name + model_extname)
    print('path_model', path_model)
    model_info[model_desc] = path_model
    models_list.append(model_desc)

    if not os.path.exists(path_model):
        download_link = "https://drive.google.com/uc?export=download&confirm=t&id={}".format(gdrive_id)

        # Download to a temporary name and move it into place only on success.
        # `wget -O` creates the output file even when the download fails, and
        # a leftover partial/empty file at path_model would make the exists()
        # check above skip the download on every later run.
        tmp_path = path_model + ".part"
        cmd = ["wget", "-O", tmp_path, "--content-disposition", download_link]
        print('wget -O {} --content-disposition "{}"'.format(tmp_path, download_link))

        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed through verbatim.
        try:
            returncode = subprocess.run(cmd, check=False).returncode
        except OSError as err:  # e.g. wget is not installed
            print("Cannot run wget:", err)
            returncode = None

        if returncode == 0:
            os.replace(tmp_path, path_model)
        else:
            print("Download failed for", path_model, "(wget exit code {})".format(returncode))
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
|
|
|
|
|
|
|
|
|
# Without any configured model, models_list[0] below would raise IndexError;
# show a readable error and end this script run instead.
if not models_list:
    st.error("No models are configured.")
    st.stop()

# Fall back to the first configured model when the requested one (from the
# URL or the default) is not among the configured descriptions.
if qmodel not in models_list:
    qmodel = models_list[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
st.sidebar.header("Model Config")

# Not read anywhere else in this script; kept as a label of the task this app performs.
task_type = "Detection"

# Model picker; the option pre-selected is qmodel (URL parameter or fallback).
model_type = st.sidebar.selectbox(
    "Models list",
    models_list,
    index=models_list.index(qmodel),
)

# Confidence threshold chosen in percent and converted to a 0-1 fraction.
# Streamlit requires the default `value` to lie within [min_value, max_value];
# the previous positional call passed 5 as `value` (outside [10, 100]) and
# raised StreamlitAPIException. 5 is used as the step here, and the default is
# the minimum.
confidence = float(st.sidebar.slider(
    "Select Model Confidence",
    min_value=10,
    max_value=100,
    value=10,
    step=5,
)) / 100
|
|
|
# Load the weights of the selected model. `model` is bound up front so that a
# failed load (or no selection) cannot leave it undefined for the source
# dispatch below, which previously crashed with NameError.
model = None
if model_type:
    st.header('{} Model Trial'.format(model_type), divider='rainbow')
    st.subheader('Use your :blue[photo/video] :sunglasses:')
    model_path = model_info[model_type]
    print('model_path', model_path)
    try:
        model = load_model(model_path)
    except Exception as e:
        # UI boundary: report to the user, but keep the cause in the logs too.
        print('load_model failed:', e)
        st.error(f"Unable to load model. Please check the specified path: {model_path}")
else:
    st.error("Please Select Model in Sidebar")

if model is None:
    # Inference cannot run without a model; end this script run after the
    # error message instead of failing later.
    st.stop()
|
|
|
|
|
st.sidebar.header("Image/Video Config")
source_selectbox = st.sidebar.selectbox(
    "Select Source",
    config.SOURCES_LIST
)

# The first three entries of config.SOURCES_LIST map, in order, to the
# uploaded-image, uploaded-video and webcam inference helpers.
if source_selectbox == config.SOURCES_LIST[0]:
    infer_uploaded_image(confidence, model)
elif source_selectbox == config.SOURCES_LIST[1]:
    infer_uploaded_video(confidence, model)
elif source_selectbox == config.SOURCES_LIST[2]:
    infer_uploaded_webcam(confidence, model)
else:
    # Name the unsupported choice rather than claiming only two sources exist.
    st.error(f"Source {source_selectbox!r} is not implemented yet")
|
|