Felladrin commited on
Commit
1364cbe
·
verified ·
1 Parent(s): 58a0ecb

Refactor the code, making it more robust and maintainable

Browse files
Files changed (1) hide show
  1. app.py +193 -121
app.py CHANGED
@@ -1,139 +1,211 @@
 
1
  import os
2
  import subprocess
3
  import sys
4
- import tarfile
5
- import tempfile
6
- import urllib.request
 
7
 
8
  import streamlit as st
9
  from huggingface_hub import HfApi
10
 
11
- HF_TOKEN = st.secrets.get("HF_TOKEN") or os.environ.get("HF_TOKEN")
12
- HF_USERNAME = (
13
- st.secrets.get("HF_USERNAME")
14
- or os.environ.get("HF_USERNAME")
15
- or os.environ.get("SPACE_AUTHOR_NAME")
16
- )
17
-
18
- TRANSFORMERS_BASE_URL = "https://github.com/xenova/transformers.js/archive/refs"
19
- TRANSFORMERS_REPOSITORY_REVISION = "3.0.0"
20
- TRANSFORMERS_REF_TYPE = (
21
- "tags"
22
- if urllib.request.urlopen(
23
- f"{TRANSFORMERS_BASE_URL}/tags/{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
24
- ).getcode()
25
- == 200
26
- else "heads"
27
- )
28
- TRANSFORMERS_REPOSITORY_URL = f"{TRANSFORMERS_BASE_URL}/{TRANSFORMERS_REF_TYPE}/{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
29
- TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
30
- ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
31
- HF_BASE_URL = "https://huggingface.co"
32
-
33
- if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
34
- urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)
35
-
36
- with tempfile.TemporaryDirectory() as tmp_dir:
37
- with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
38
- tar.extractall(tmp_dir)
39
-
40
- extracted_folder = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
41
-
42
- os.rename(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
43
-
44
- os.remove(ARCHIVE_PATH)
45
- print("Repository downloaded and extracted successfully.")
46
-
47
- st.write("## Convert a HuggingFace model to ONNX")
48
-
49
- input_model_id = st.text_input(
50
- "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
51
- )
52
-
53
- if input_model_id:
54
- model_name = (
55
- input_model_id.replace(f"{HF_BASE_URL}/", "")
56
- .replace("/", "-")
57
- .replace(f"{HF_USERNAME}-", "")
58
- .strip()
59
  )
60
- output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
61
- output_model_url = f"{HF_BASE_URL}/{output_model_id}"
62
- api = HfApi(token=HF_TOKEN)
63
- repo_exists = api.repo_exists(output_model_id)
64
-
65
- if repo_exists:
66
- st.write("This model has already been converted! 🎉")
67
- st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
68
- else:
69
- st.write(f"This model will be converted and uploaded to the following URL:")
70
- st.code(output_model_url, language="plaintext")
71
- start_conversion = st.button(label="Proceed", type="primary")
72
-
73
- if start_conversion:
74
- with st.spinner("Converting model..."):
75
- output = subprocess.run(
76
- [
77
- sys.executable,
78
- "-m",
79
- "scripts.convert",
80
- "--quantize",
81
- "--model_id",
82
- input_model_id,
83
- ],
84
- cwd=TRANSFORMERS_REPOSITORY_PATH,
85
- capture_output=True,
86
- text=True,
87
- env={},
88
- )
89
-
90
- # Log the script output
91
- print("### Script Output ###")
92
- print(output.stdout)
93
-
94
- # Log any errors
95
- if output.stderr:
96
- print("### Script Errors ###")
97
- print(output.stderr)
98
-
99
- model_folder_path = (
100
- f"{TRANSFORMERS_REPOSITORY_PATH}/models/{input_model_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  )
102
 
103
- os.rename(
104
- f"{model_folder_path}/onnx/model.onnx",
105
- f"{model_folder_path}/onnx/decoder_model_merged.onnx",
106
- )
107
- os.rename(
108
- f"{model_folder_path}/onnx/model_quantized.onnx",
109
- f"{model_folder_path}/onnx/decoder_model_merged_quantized.onnx",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  )
 
 
 
 
 
111
 
112
- st.success("Conversion successful!")
113
 
114
- st.code(output.stderr)
115
 
116
- with st.spinner("Uploading model..."):
117
- repository = api.create_repo(
118
- f"{output_model_id}", exist_ok=True, private=False
119
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- upload_error_message = None
 
 
122
 
123
- try:
124
- api.upload_folder(
125
- folder_path=model_folder_path, repo_id=repository.repo_id
126
- )
127
- except Exception as e:
128
- upload_error_message = str(e)
129
 
130
- os.system(f"rm -rf {model_folder_path}")
131
 
132
- if upload_error_message:
133
- st.error(f"Upload failed: {upload_error_message}")
134
- else:
135
- st.success(f"Upload successful!")
136
- st.write("You can now go and view the model on HuggingFace!")
137
- st.link_button(
138
- f"Go to {output_model_id}", output_model_url, type="primary"
139
- )
 
1
+ import logging
2
  import os
3
  import subprocess
4
  import sys
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Optional, Tuple
8
+ from urllib.request import urlopen, urlretrieve
9
 
10
  import streamlit as st
11
  from huggingface_hub import HfApi
12
 
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ @dataclass
18
+ class Config:
19
+ """Application configuration."""
20
+
21
+ hf_token: str
22
+ hf_username: str
23
+ transformers_version: str = "3.0.0"
24
+ hf_base_url: str = "https://huggingface.co"
25
+ transformers_base_url: str = (
26
+ "https://github.com/xenova/transformers.js/archive/refs"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  )
28
+ repo_path: Path = Path("./transformers.js")
29
+
30
+ @classmethod
31
+ def from_env(cls) -> "Config":
32
+ """Create config from environment variables and secrets."""
33
+ hf_token = st.secrets.get("HF_TOKEN") or os.getenv("HF_TOKEN", "")
34
+ hf_username = (
35
+ st.secrets.get("HF_USERNAME")
36
+ or os.getenv("HF_USERNAME")
37
+ or os.getenv("SPACE_AUTHOR_NAME", "")
38
+ )
39
+
40
+ if not hf_token or not hf_username:
41
+ raise ValueError("HF_TOKEN and HF_USERNAME must be set")
42
+
43
+ return cls(hf_token=hf_token, hf_username=hf_username)
44
+
45
+
46
+ class ModelConverter:
47
+ """Handles model conversion and upload operations."""
48
+
49
+ def __init__(self, config: Config):
50
+ self.config = config
51
+ self.api = HfApi(token=config.hf_token)
52
+
53
+ def _get_ref_type(self) -> str:
54
+ """Determine the reference type for the transformers repository."""
55
+ url = f"{self.config.transformers_base_url}/tags/{self.config.transformers_version}.tar.gz"
56
+ try:
57
+ return "tags" if urlopen(url).getcode() == 200 else "heads"
58
+ except Exception as e:
59
+ logger.warning(f"Failed to check tags, defaulting to heads: {e}")
60
+ return "heads"
61
+
62
+ def setup_repository(self) -> None:
63
+ """Download and setup transformers repository if needed."""
64
+ if self.config.repo_path.exists():
65
+ return
66
+
67
+ ref_type = self._get_ref_type()
68
+ archive_url = f"{self.config.transformers_base_url}/{ref_type}/{self.config.transformers_version}.tar.gz"
69
+ archive_path = Path(f"./transformers_{self.config.transformers_version}.tar.gz")
70
+
71
+ try:
72
+ urlretrieve(archive_url, archive_path)
73
+ self._extract_archive(archive_path)
74
+ logger.info("Repository downloaded and extracted successfully")
75
+ except Exception as e:
76
+ raise RuntimeError(f"Failed to setup repository: {e}")
77
+ finally:
78
+ archive_path.unlink(missing_ok=True)
79
+
80
+ def _extract_archive(self, archive_path: Path) -> None:
81
+ """Extract the downloaded archive."""
82
+ import tarfile
83
+ import tempfile
84
+
85
+ with tempfile.TemporaryDirectory() as tmp_dir:
86
+ with tarfile.open(archive_path, "r:gz") as tar:
87
+ tar.extractall(tmp_dir)
88
+
89
+ extracted_folder = next(Path(tmp_dir).iterdir())
90
+ extracted_folder.rename(self.config.repo_path)
91
+
92
+ def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
93
+ """Convert the model to ONNX format."""
94
+ try:
95
+ result = subprocess.run(
96
+ [
97
+ sys.executable,
98
+ "-m",
99
+ "scripts.convert",
100
+ "--quantize",
101
+ "--model_id",
102
+ input_model_id,
103
+ ],
104
+ cwd=self.config.repo_path,
105
+ capture_output=True,
106
+ text=True,
107
+ env={},
108
  )
109
 
110
+ if result.returncode != 0:
111
+ return False, result.stderr
112
+
113
+ self._rename_model_files(input_model_id)
114
+ return True, result.stderr
115
+
116
+ except Exception as e:
117
+ return False, str(e)
118
+
119
+ def _rename_model_files(self, input_model_id: str) -> None:
120
+ """Rename the converted model files."""
121
+ model_path = self.config.repo_path / "models" / input_model_id / "onnx"
122
+
123
+ renames = [
124
+ ("model.onnx", "decoder_model_merged.onnx"),
125
+ ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
126
+ ]
127
+
128
+ for old_name, new_name in renames:
129
+ (model_path / old_name).rename(model_path / new_name)
130
+
131
+ def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
132
+ """Upload the converted model to Hugging Face."""
133
+ try:
134
+ self.api.create_repo(output_model_id, exist_ok=True, private=False)
135
+ model_folder_path = self.config.repo_path / "models" / input_model_id
136
+
137
+ self.api.upload_folder(
138
+ folder_path=str(model_folder_path), repo_id=output_model_id
139
  )
140
+ return None
141
+ except Exception as e:
142
+ return str(e)
143
+ finally:
144
+ import shutil
145
 
146
+ shutil.rmtree(model_folder_path, ignore_errors=True)
147
 
 
148
 
149
+ def main():
150
+ """Main application entry point."""
151
+ st.write("## Convert a Hugging Face model to ONNX")
152
+
153
+ try:
154
+ config = Config.from_env()
155
+ converter = ModelConverter(config)
156
+ converter.setup_repository()
157
+
158
+ input_model_id = st.text_input(
159
+ "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
160
+ )
161
+
162
+ if not input_model_id:
163
+ return
164
+
165
+ model_name = (
166
+ input_model_id.replace(f"{config.hf_base_url}/", "")
167
+ .replace("/", "-")
168
+ .replace(f"{config.hf_username}-", "")
169
+ .strip()
170
+ )
171
+
172
+ output_model_id = f"{config.hf_username}/{model_name}-ONNX"
173
+ output_model_url = f"{config.hf_base_url}/{output_model_id}"
174
+
175
+ if converter.api.repo_exists(output_model_id):
176
+ st.write("This model has already been converted! 🎉")
177
+ st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
178
+ return
179
+
180
+ st.write(f"This model will be converted and uploaded to the following URL:")
181
+ st.code(output_model_url, language="plaintext")
182
+
183
+ if not st.button(label="Proceed", type="primary"):
184
+ return
185
+
186
+ with st.spinner("Converting model..."):
187
+ success, stderr = converter.convert_model(input_model_id)
188
+ if not success:
189
+ st.error(f"Conversion failed: {stderr}")
190
+ return
191
+
192
+ st.success("Conversion successful!")
193
+ st.code(stderr)
194
+
195
+ with st.spinner("Uploading model..."):
196
+ error = converter.upload_model(input_model_id, output_model_id)
197
+ if error:
198
+ st.error(f"Upload failed: {error}")
199
+ return
200
 
201
+ st.success("Upload successful!")
202
+ st.write("You can now go and view the model on Hugging Face!")
203
+ st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
204
 
205
+ except Exception as e:
206
+ logger.exception("Application error")
207
+ st.error(f"An error occurred: {str(e)}")
 
 
 
208
 
 
209
 
210
+ if __name__ == "__main__":
211
+ main()