urroxyz committed on
Commit
2f5e58b
·
verified ·
1 Parent(s): 0ceadd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -65
app.py CHANGED
@@ -2,7 +2,6 @@ import logging
2
  import os
3
  import subprocess
4
  import sys
5
- import warnings
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
  from typing import Optional, Tuple
@@ -10,11 +9,6 @@ from urllib.request import urlopen, urlretrieve
10
 
11
  import streamlit as st
12
  from huggingface_hub import HfApi, whoami
13
- from torch.jit import TracerWarning
14
- from transformers import AutoConfig, GenerationConfig
15
-
16
- # Suppress local TorchScript tracer warnings
17
- warnings.filterwarnings("ignore", category=TracerWarning)
18
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
@@ -23,6 +17,7 @@ logger = logging.getLogger(__name__)
23
  @dataclass
24
  class Config:
25
  """Application configuration."""
 
26
  hf_token: str
27
  hf_username: str
28
  transformers_version: str = "3.5.0"
@@ -44,8 +39,10 @@ class Config:
44
  os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
45
  )
46
  hf_token = user_token or system_token
 
47
  if not hf_token:
48
  raise ValueError("HF_TOKEN must be set")
 
49
  return cls(hf_token=hf_token, hf_username=hf_username)
50
 
51
 
@@ -66,12 +63,14 @@ class ModelConverter:
66
  return "heads"
67
 
68
  def setup_repository(self) -> None:
69
- """Download and setup transformers.js repo if needed."""
70
  if self.config.repo_path.exists():
71
  return
 
72
  ref_type = self._get_ref_type()
73
  archive_url = f"{self.config.transformers_base_url}/{ref_type}/{self.config.transformers_version}.tar.gz"
74
  archive_path = Path(f"./transformers_{self.config.transformers_version}.tar.gz")
 
75
  try:
76
  urlretrieve(archive_url, archive_path)
77
  self._extract_archive(archive_path)
@@ -83,38 +82,19 @@ class ModelConverter:
83
 
84
  def _extract_archive(self, archive_path: Path) -> None:
85
  """Extract the downloaded archive."""
86
- import tarfile, tempfile
 
 
87
  with tempfile.TemporaryDirectory() as tmp_dir:
88
  with tarfile.open(archive_path, "r:gz") as tar:
89
  tar.extractall(tmp_dir)
 
90
  extracted_folder = next(Path(tmp_dir).iterdir())
91
  extracted_folder.rename(self.config.repo_path)
92
 
93
  def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
94
- """
95
- Convert the model to ONNX, always exporting attention maps.
96
- Relocate generation params, suppress tracer warnings, and
97
- filter out relocation/tracer warnings from stderr.
98
- """
99
  try:
100
- # 1. Prepare a local folder for config tweaks
101
- model_dir = self.config.repo_path / "models" / input_model_id
102
- model_dir.mkdir(parents=True, exist_ok=True)
103
-
104
- # 2. Move any generation parameters into generation_config.json
105
- base_cfg = AutoConfig.from_pretrained(input_model_id)
106
- gen_cfg = GenerationConfig.from_model_config(base_cfg)
107
- for k in gen_cfg.to_dict():
108
- if hasattr(base_cfg, k):
109
- setattr(base_cfg, k, None)
110
- base_cfg.save_pretrained(model_dir)
111
- gen_cfg.save_pretrained(model_dir)
112
-
113
- # 3. Set verbose logging via env var (no --debug flag)
114
- env = os.environ.copy()
115
- env["TRANSFORMERS_VERBOSITY"] = "debug"
116
-
117
- # 4. Build and run the conversion command
118
  cmd = [
119
  sys.executable,
120
  "-m", "scripts.convert",
@@ -128,43 +108,41 @@ class ModelConverter:
128
  cwd=self.config.repo_path,
129
  capture_output=True,
130
  text=True,
131
- env=env,
132
  )
133
 
134
- # 5. Filter out spurious warnings from stderr
135
- filtered = []
136
- for ln in result.stderr.splitlines():
137
- if ln.startswith("Moving the following attributes"):
138
- continue
139
- if "TracerWarning" in ln:
140
- continue
141
- filtered.append(ln)
142
- stderr = "\n".join(filtered)
143
-
144
  if result.returncode != 0:
145
- return False, stderr
146
- return True, stderr
 
147
 
148
  except Exception as e:
149
  return False, str(e)
150
 
151
  def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
152
- """Upload the converted model to Hugging Face Hub."""
153
- model_folder = self.config.repo_path / "models" / input_model_id
 
154
  try:
155
  self.api.create_repo(output_model_id, exist_ok=True, private=False)
156
- readme_path = model_folder / "README.md"
157
- if not readme_path.exists():
158
- readme_path.write_text(self.generate_readme(input_model_id))
159
- self.api.upload_folder(folder_path=str(model_folder), repo_id=output_model_id)
 
 
 
 
 
 
160
  return None
161
  except Exception as e:
162
  return str(e)
163
  finally:
164
  import shutil
165
- shutil.rmtree(model_folder, ignore_errors=True)
166
 
167
- def generate_readme(self, imi: str) -> str:
168
  return (
169
  "---\n"
170
  "library_name: transformers.js\n"
@@ -173,13 +151,14 @@ class ModelConverter:
173
  "---\n\n"
174
  f"# {imi.split('/')[-1]} (ONNX)\n\n"
175
  f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
176
- "Converted with attention maps and verbose export logs.\n"
 
177
  )
178
 
179
 
180
  def main():
181
- """Streamlit application entry point."""
182
- st.write("## Convert a Hugging Face model to ONNX (with attentions & debug logs)")
183
 
184
  try:
185
  config = Config.from_env()
@@ -187,19 +166,21 @@ def main():
187
  converter.setup_repository()
188
 
189
  input_model_id = st.text_input(
190
- "Enter the Hugging Face model ID to convert, e.g. `EleutherAI/pythia-14m`"
191
  )
192
  if not input_model_id:
193
  return
194
 
195
  st.text_input(
196
- "Optional: Your Hugging Face write token (for uploading to your namespace).",
197
  type="password",
198
  key="user_hf_token",
199
  )
200
 
201
  if config.hf_username == input_model_id.split("/")[0]:
202
- same_repo = st.checkbox("Upload ONNX weights to the same repository?")
 
 
203
  else:
204
  same_repo = False
205
 
@@ -208,14 +189,20 @@ def main():
208
  if not same_repo:
209
  output_model_id += "-ONNX"
210
 
211
- output_url = f"{config.hf_base_url}/{output_model_id}"
 
 
 
 
 
 
212
  st.write("Destination repository:")
213
- st.code(output_url, language="plaintext")
214
 
215
- if not st.button("Proceed", type="primary"):
216
  return
217
 
218
- with st.spinner("Converting model…"):
219
  success, stderr = converter.convert_model(input_model_id)
220
  if not success:
221
  st.error(f"Conversion failed: {stderr}")
@@ -229,12 +216,14 @@ def main():
229
  st.error(f"Upload failed: {error}")
230
  return
231
  st.success("Upload successful!")
232
- st.link_button(f"Go to {output_model_id}", output_url, type="primary")
 
233
 
234
  except Exception as e:
235
  logger.exception("Application error")
236
- st.error(f"An error occurred: {e}")
237
 
238
 
239
  if __name__ == "__main__":
240
- main()
 
 
2
  import os
3
  import subprocess
4
  import sys
 
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
  from typing import Optional, Tuple
 
9
 
10
  import streamlit as st
11
  from huggingface_hub import HfApi, whoami
 
 
 
 
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
 
17
  @dataclass
18
  class Config:
19
  """Application configuration."""
20
+
21
  hf_token: str
22
  hf_username: str
23
  transformers_version: str = "3.5.0"
 
39
  os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
40
  )
41
  hf_token = user_token or system_token
42
+
43
  if not hf_token:
44
  raise ValueError("HF_TOKEN must be set")
45
+
46
  return cls(hf_token=hf_token, hf_username=hf_username)
47
 
48
 
 
63
  return "heads"
64
 
65
  def setup_repository(self) -> None:
66
+ """Download and setup transformers repository if needed."""
67
  if self.config.repo_path.exists():
68
  return
69
+
70
  ref_type = self._get_ref_type()
71
  archive_url = f"{self.config.transformers_base_url}/{ref_type}/{self.config.transformers_version}.tar.gz"
72
  archive_path = Path(f"./transformers_{self.config.transformers_version}.tar.gz")
73
+
74
  try:
75
  urlretrieve(archive_url, archive_path)
76
  self._extract_archive(archive_path)
 
82
 
83
  def _extract_archive(self, archive_path: Path) -> None:
84
  """Extract the downloaded archive."""
85
+ import tarfile
86
+ import tempfile
87
+
88
  with tempfile.TemporaryDirectory() as tmp_dir:
89
  with tarfile.open(archive_path, "r:gz") as tar:
90
  tar.extractall(tmp_dir)
91
+
92
  extracted_folder = next(Path(tmp_dir).iterdir())
93
  extracted_folder.rename(self.config.repo_path)
94
 
95
  def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
96
+ """Convert the model to ONNX format, always exporting attention maps."""
 
 
 
 
97
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  cmd = [
99
  sys.executable,
100
  "-m", "scripts.convert",
 
108
  cwd=self.config.repo_path,
109
  capture_output=True,
110
  text=True,
111
+ env={},
112
  )
113
 
 
 
 
 
 
 
 
 
 
 
114
  if result.returncode != 0:
115
+ return False, result.stderr
116
+
117
+ return True, result.stderr
118
 
119
  except Exception as e:
120
  return False, str(e)
121
 
122
  def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
123
+ """Upload the converted model to Hugging Face."""
124
+ model_folder_path = self.config.repo_path / "models" / input_model_id
125
+
126
  try:
127
  self.api.create_repo(output_model_id, exist_ok=True, private=False)
128
+
129
+ readme_path = f"{model_folder_path}/README.md"
130
+
131
+ if not os.path.exists(readme_path):
132
+ with open(readme_path, "w") as file:
133
+ file.write(self.generate_readme(input_model_id))
134
+
135
+ self.api.upload_folder(
136
+ folder_path=str(model_folder_path), repo_id=output_model_id
137
+ )
138
  return None
139
  except Exception as e:
140
  return str(e)
141
  finally:
142
  import shutil
143
+ shutil.rmtree(model_folder_path, ignore_errors=True)
144
 
145
+ def generate_readme(self, imi: str):
146
  return (
147
  "---\n"
148
  "library_name: transformers.js\n"
 
151
  "---\n\n"
152
  f"# {imi.split('/')[-1]} (ONNX)\n\n"
153
  f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
154
+ "It was automatically converted and uploaded using "
155
+ "[this space](https://huggingface.co/spaces/onnx-community/convert-to-onnx).\n"
156
  )
157
 
158
 
159
  def main():
160
+ """Main application entry point."""
161
+ st.write("## Convert a Hugging Face model to ONNX (with attentions)")
162
 
163
  try:
164
  config = Config.from_env()
 
166
  converter.setup_repository()
167
 
168
  input_model_id = st.text_input(
169
+ "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
170
  )
171
  if not input_model_id:
172
  return
173
 
174
  st.text_input(
175
+ "Optional: Your Hugging Face write token. Fill it if you want to upload under your account.",
176
  type="password",
177
  key="user_hf_token",
178
  )
179
 
180
  if config.hf_username == input_model_id.split("/")[0]:
181
+ same_repo = st.checkbox(
182
+ "Upload ONNX weights to the same repository?"
183
+ )
184
  else:
185
  same_repo = False
186
 
 
189
  if not same_repo:
190
  output_model_id += "-ONNX"
191
 
192
+ output_model_url = f"{config.hf_base_url}/{output_model_id}"
193
+
194
+ if not same_repo and converter.api.repo_exists(output_model_id):
195
+ st.write("This model has already been converted! 🎉")
196
+ st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
197
+ return
198
+
199
  st.write("Destination repository:")
200
+ st.code(output_model_url, language="plaintext")
201
 
202
+ if not st.button(label="Proceed", type="primary"):
203
  return
204
 
205
+ with st.spinner("Converting model (including attention maps)…"):
206
  success, stderr = converter.convert_model(input_model_id)
207
  if not success:
208
  st.error(f"Conversion failed: {stderr}")
 
216
  st.error(f"Upload failed: {error}")
217
  return
218
  st.success("Upload successful!")
219
+ st.write("You can now view the model on Hugging Face:")
220
+ st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
221
 
222
  except Exception as e:
223
  logger.exception("Application error")
224
+ st.error(f"An error occurred: {str(e)}")
225
 
226
 
227
  if __name__ == "__main__":
228
+ main()
229
+