# Streamlit YOLOv5 Model2X v0.1
# Author: Zeng Yifu
# Created: 2022-07-14
# Description: multi-select export formats for an uploaded YOLOv5 .pt model,
#              convert it with export.py, and download everything as one zip.

import os
import shutil
import subprocess
import sys
import time
import zipfile

import streamlit as st

# Working directory that holds the uploaded weights and all exported models.
WEIGHTS_DIR = "./weights"
# Archive built from WEIGHTS_DIR and offered to the user for download.
ZIP_NAME = "convert_weights.zip"

# (checkbox label, export.py --include value) pairs. Keeping each label next to
# its --include value guarantees the checkbox states and the format names stay
# index-aligned. TensorRT ("engine") and Edge TPU ("edgetpu") are deliberately
# not offered by this tool.
EXPORT_FORMATS = (
    ("TorchScript", "torchscript"),
    ("ONNX", "onnx"),
    ("OpenVINO", "openvino"),
    ("CoreML", "coreml"),
    ("TensorFlow SavedModel", "saved_model"),
    ("TensorFlow GraphDef", "pb"),
    ("TensorFlow Lite", "tflite"),
    ("TensorFlow.js", "tfjs"),
)


def dir_opt(target_dir):
    """Ensure ``target_dir`` exists and is empty.

    Any existing directory at that path is deleted recursively and then
    recreated.
    """
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.mkdir(target_dir)


def download_file(uploaded_file):
    """Show a Streamlit download button serving the file at ``uploaded_file``.

    Args:
        uploaded_file: Path of the local file to offer for download.
    """
    with open(f"{uploaded_file}", "rb") as fmodel:
        f_download_model = fmodel.read()
    st.download_button(
        label='下载转换后的模型',
        data=f_download_model,
        # Offer just the file's name, not any directory components.
        file_name=os.path.basename(f"{uploaded_file}"),
    )


def zipDir(origin_dir, compress_file):
    """Compress every file under ``origin_dir`` into ``compress_file``.

    Archive member names are stored relative to ``origin_dir``, so the
    directory itself does not appear as a prefix inside the zip.

    Args:
        origin_dir: Directory to walk recursively.
        compress_file: Path of the zip archive to create (overwritten).
    """
    origin_dir = f"{origin_dir}"
    with zipfile.ZipFile(f"{compress_file}", "w", zipfile.ZIP_DEFLATED) as zf:
        for path, _dirnames, filenames in os.walk(origin_dir):
            rel_dir = os.path.relpath(path, origin_dir)
            for filename in filenames:
                # normpath collapses the leading "./" produced for the root dir.
                arcname = os.path.normpath(os.path.join(rel_dir, filename))
                zf.write(os.path.join(path, filename), arcname)


def cb_opt(weight_name, btn_model_list, params_include_list):
    """Export ``weight_name`` to every selected format, then zip and offer it.

    Args:
        weight_name: File name of the uploaded weights inside WEIGHTS_DIR.
        btn_model_list: Checkbox states; ``btn_model_list[i]`` selects
            ``params_include_list[i]``.
        params_include_list: ``--include`` values passed to export.py, aligned
            index-by-index with ``btn_model_list``.
    """
    weight_path = os.path.join(WEIGHTS_DIR, weight_name)
    for selected, include in zip(btn_model_list, params_include_list):
        if not selected:
            continue
        st.info(f"正在转换{include}......")
        start = time.perf_counter()
        # Argument list (no shell): the weight name comes from the uploaded
        # file and must not be interpreted by a shell. sys.executable runs
        # export.py with the same interpreter that runs this app. export.py
        # is resolved relative to the current working directory.
        result = subprocess.run(
            [sys.executable, "export.py",
             "--weights", weight_path,
             "--include", include],
            check=False,
        )
        elapsed = time.perf_counter() - start
        if result.returncode == 0:
            st.success(f"{include}转换完成,用时{round(elapsed, 2)}秒")
        else:
            st.error(f"{include}转换失败,返回码{result.returncode}")
    # Bundle the whole weights directory: original weights plus all exports.
    zipDir(WEIGHTS_DIR, ZIP_NAME)
    download_file(ZIP_NAME)


def main():
    """Render the Streamlit UI: upload, choose formats, convert, download."""
    with st.container():
        st.title("Streamlit YOLOv5 Model2X")
        st.subheader('创建人:曾逸夫(Zeng Yifu)')
        st.text("基于Streamlit的YOLOv5模型转换工具")
        st.write("-------------------------------------------------------------")

    dir_opt(WEIGHTS_DIR)
    uploaded_file = st.file_uploader("选择YOLOv5模型文件(.pt)", type=["pt"])
    if uploaded_file is not None:
        # basename() keeps a crafted upload name from escaping WEIGHTS_DIR.
        weight_name = os.path.basename(uploaded_file.name)
        st.info(f"正在写入{weight_name}......")
        with open(os.path.join(WEIGHTS_DIR, weight_name), "wb") as fb:
            fb.write(uploaded_file.getvalue())
        st.success(f"{weight_name}写入成功!")

        st.text("请选择转换的类型:")
        btn_model_list = [st.checkbox(label) for label, _ in EXPORT_FORMATS]
        params_include_list = [include for _, include in EXPORT_FORMATS]
        btn_convert = st.button('转换')

        if btn_convert:
            cb_opt(weight_name, btn_model_list, params_include_list)

    st.write("-------------------------------------------------------------")


if __name__ == "__main__":
    main()