# Scraped Hugging Face Space page header, not part of the program:
# "Spaces: Paused" (Space status at capture time).
import os
import re
import shutil
import subprocess

import streamlit as st
from huggingface_hub import HfApi
# --- Hugging Face settings ---------------------------------------------------
HF_BASE_URL = "https://huggingface.co"

# Credentials: a value from Streamlit secrets wins over the environment.
# For the username, SPACE_AUTHOR_NAME is consulted as the last fallback.
# NOTE(review): depending on the Streamlit version, st.secrets.get() may raise
# when no secrets file exists at all -- confirm on the deployment target.
HF_TOKEN = (
    st.secrets.get("HF_TOKEN")
    or os.environ.get("HF_TOKEN")
)
HF_USERNAME = (
    st.secrets.get("HF_USERNAME")
    or os.environ.get("HF_USERNAME")
    or os.environ.get("SPACE_AUTHOR_NAME")
)

# --- transformers.js checkout that provides the conversion script ------------
TRANSFORMERS_REPOSITORY_URL = "https://github.com/xenova/transformers.js"
TRANSFORMERS_REPOSITORY_REVISION = "2.16.0"  # git ref checked out after cloning
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"  # relative to the working dir
# Fetch the transformers.js sources once, pinned to TRANSFORMERS_REPOSITORY_REVISION.
# The conversion step later runs `python -m scripts.convert` from inside this
# checkout.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    try:
        # Argument lists (no shell) plus check=True: a failed clone or checkout
        # raises here instead of being silently ignored.
        subprocess.run(
            [
                "git",
                "clone",
                TRANSFORMERS_REPOSITORY_URL,
                TRANSFORMERS_REPOSITORY_PATH,
            ],
            check=True,
        )
        subprocess.run(
            ["git", "checkout", TRANSFORMERS_REPOSITORY_REVISION],
            cwd=TRANSFORMERS_REPOSITORY_PATH,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # Drop a half-finished checkout. Otherwise the exists() check above
        # would skip cloning on every later run and keep a broken/unpinned tree.
        shutil.rmtree(TRANSFORMERS_REPOSITORY_PATH, ignore_errors=True)
        raise
# Page heading, then the single text box whose value gates everything below.
st.write("## Convert a HuggingFace model to ONNX")
input_model_id = st.text_input(
    label="Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`",
)
# --- Conversion flow ---------------------------------------------------------
# Once a model id has been entered: derive the output repo name, then either
# link to an existing conversion or offer to convert it with transformers.js'
# `scripts.convert` and upload the result to the Hub.
if input_model_id:
    # Accept a bare repo id ("org/name") or a full model URL, ignoring
    # surrounding whitespace. The normalised id is what is passed to the
    # converter AND used for the local output folder, so the two always agree.
    input_model_id = input_model_id.strip()
    model_url_prefix = f"{HF_BASE_URL}/"
    if input_model_id.startswith(model_url_prefix):
        input_model_id = input_model_id[len(model_url_prefix):].strip("/")

    # The id becomes part of a filesystem path that is deleted after upload,
    # so only accept repo-id characters. Each part must start with an
    # alphanumeric or "_", which rules out "." / ".." path components and
    # leading "-" (which the converter could mistake for a flag).
    if not re.fullmatch(
        r"[A-Za-z0-9_][A-Za-z0-9_.-]*(/[A-Za-z0-9_][A-Za-z0-9_.-]*)?",
        input_model_id,
    ):
        st.error(f"`{input_model_id}` is not a valid HuggingFace model ID.")
        st.stop()
    # Without a username the upload target would literally be "None/...".
    if not HF_USERNAME:
        st.error(
            "No HuggingFace username configured "
            "(HF_USERNAME / SPACE_AUTHOR_NAME); cannot choose an upload target."
        )
        st.stop()

    # "org/name" -> "org-name". A leading "<our username>-" is dropped so that
    # re-converting one of our own repos does not repeat the username. Only the
    # prefix is stripped; str.replace would also cut it out of the middle.
    model_name = input_model_id.replace("/", "-")
    own_prefix = f"{HF_USERNAME}-"
    if model_name.startswith(own_prefix):
        model_name = model_name[len(own_prefix):]
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"

    api = HfApi(token=HF_TOKEN)
    repo_exists = api.repo_exists(output_model_id)
    if repo_exists:
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    else:
        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")
        start_conversion = st.button(label="Proceed", type="primary")
        if start_conversion:
            # Output location of scripts.convert inside the checkout.
            model_folder_path = f"{TRANSFORMERS_REPOSITORY_PATH}/models/{input_model_id}"
            try:
                with st.spinner("Converting model..."):
                    # Argument list, no shell: the user-supplied id is never
                    # interpreted by a shell.
                    output = subprocess.run(
                        [
                            "python",
                            "-m",
                            "scripts.convert",
                            "--quantize",
                            "--model_id",
                            input_model_id,
                        ],
                        cwd=TRANSFORMERS_REPOSITORY_PATH,
                        capture_output=True,
                        text=True,
                    )
                if output.returncode != 0:
                    st.error("Conversion failed!")
                    st.code(output.stderr)
                else:
                    # Rename the single-file export to the decoder_model_merged*
                    # names (presumably what transformers.js expects for
                    # decoder-only models -- confirm). Files the export did not
                    # produce are skipped instead of crashing the app.
                    onnx_folder_path = f"{model_folder_path}/onnx"
                    for source_name, target_name in (
                        ("model.onnx", "decoder_model_merged.onnx"),
                        ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
                    ):
                        source_path = f"{onnx_folder_path}/{source_name}"
                        if os.path.exists(source_path):
                            os.rename(source_path, f"{onnx_folder_path}/{target_name}")
                    st.success("Conversion successful!")
                    st.code(output.stderr)

                    with st.spinner("Uploading model..."):
                        upload_error_message = None
                        try:
                            repository = api.create_repo(
                                output_model_id, exist_ok=True, private=False
                            )
                            api.upload_folder(
                                folder_path=model_folder_path,
                                repo_id=repository.repo_id,
                            )
                        except Exception as e:  # UI boundary: report, don't crash
                            upload_error_message = str(e)
                    if upload_error_message:
                        st.error(f"Upload failed: {upload_error_message}")
                    else:
                        st.success("Upload successful!")
                        st.write("You can now go and view the model on HuggingFace!")
                        st.link_button(
                            f"Go to {output_model_id}", output_model_url, type="primary"
                        )
            finally:
                # Always remove the local output, including after a failed
                # conversion or upload, and without going through a shell.
                shutil.rmtree(model_folder_path, ignore_errors=True)