urroxyz committed
Commit 0ceadd8 · verified · 1 Parent(s): 21d5033

Update app.py

Files changed (1)
  1. app.py +99 -37
app.py CHANGED
@@ -13,7 +13,7 @@ from huggingface_hub import HfApi, whoami
 from torch.jit import TracerWarning
 from transformers import AutoConfig, GenerationConfig
 
-# Suppress local TorchScript TracerWarnings
+# Suppress local TorchScript tracer warnings
 warnings.filterwarnings("ignore", category=TracerWarning)
 
 logging.basicConfig(level=logging.INFO)
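
For context, `torch.jit.trace` emits a `TracerWarning` whenever traced code branches on tensor values, which happens routinely during ONNX export; the filter above only hides that local noise. A minimal sketch of the effect, assuming `torch` is installed:

```python
import warnings

import torch
from torch.jit import TracerWarning

# Same filter as in app.py: hide tracer warnings locally.
warnings.filterwarnings("ignore", category=TracerWarning)

def f(x):
    # Branching on a tensor value normally triggers a TracerWarning
    # ("Converting a tensor to a Python boolean...") while tracing.
    return x * 2 if x.sum() > 0 else x

traced = torch.jit.trace(f, torch.ones(3))  # traces without printing a warning
```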
@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class Config:
+    """Application configuration."""
     hf_token: str
     hf_username: str
     transformers_version: str = "3.5.0"
@@ -33,6 +34,7 @@ class Config:
 
     @classmethod
     def from_env(cls) -> "Config":
+        """Create config from environment variables and secrets."""
         system_token = st.secrets.get("HF_TOKEN")
         user_token = st.session_state.get("user_hf_token")
         if user_token:
@@ -48,11 +50,14 @@ class Config:
 
 
 class ModelConverter:
+    """Handles model conversion and upload operations."""
+
     def __init__(self, config: Config):
        self.config = config
        self.api = HfApi(token=config.hf_token)
 
     def _get_ref_type(self) -> str:
+        """Determine the reference type for the transformers repository."""
        url = f"{self.config.transformers_base_url}/tags/{self.config.transformers_version}.tar.gz"
        try:
            return "tags" if urlopen(url).getcode() == 200 else "heads"
@@ -61,6 +66,7 @@ class ModelConverter:
            return "heads"
 
     def setup_repository(self) -> None:
+        """Download and setup transformers.js repo if needed."""
        if self.config.repo_path.exists():
            return
        ref_type = self._get_ref_type()
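
`setup_repository` downloads that tarball and unpacks it via `_extract_archive` (expanded in the next hunk). GitHub tarballs wrap everything in a single top-level folder, so the extract step renames that folder into place; a self-contained sketch with placeholder paths:

```python
import tarfile
import tempfile
from pathlib import Path

archive_path = Path("transformers.js-3.5.0.tar.gz")  # placeholder archive
repo_path = Path("transformers.js")                  # placeholder destination

with tempfile.TemporaryDirectory() as tmp_dir:
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(tmp_dir)
    # The archive contains exactly one top-level folder; move it into place.
    extracted_folder = next(Path(tmp_dir).iterdir())
    extracted_folder.rename(repo_path)
```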
@@ -76,30 +82,39 @@ class ModelConverter:
         archive_path.unlink(missing_ok=True)
 
     def _extract_archive(self, archive_path: Path) -> None:
+        """Extract the downloaded archive."""
         import tarfile, tempfile
         with tempfile.TemporaryDirectory() as tmp_dir:
             with tarfile.open(archive_path, "r:gz") as tar:
                 tar.extractall(tmp_dir)
-            next(Path(tmp_dir).iterdir()).rename(self.config.repo_path)
+            extracted_folder = next(Path(tmp_dir).iterdir())
+            extracted_folder.rename(self.config.repo_path)
 
     def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
+        """
+        Convert the model to ONNX, always exporting attention maps.
+        Relocate generation params, suppress tracer warnings, and
+        filter out relocation/tracer warnings from stderr.
+        """
         try:
-            # Prepare model dir
+            # 1. Prepare a local folder for config tweaks
             model_dir = self.config.repo_path / "models" / input_model_id
             model_dir.mkdir(parents=True, exist_ok=True)
-            # Relocate generation params
+
+            # 2. Move any generation parameters into generation_config.json
             base_cfg = AutoConfig.from_pretrained(input_model_id)
             gen_cfg = GenerationConfig.from_model_config(base_cfg)
             for k in gen_cfg.to_dict():
-                if hasattr(base_cfg, k): setattr(base_cfg, k, None)
+                if hasattr(base_cfg, k):
+                    setattr(base_cfg, k, None)
             base_cfg.save_pretrained(model_dir)
             gen_cfg.save_pretrained(model_dir)
-            # Set verbose logging
+
+            # 3. Set verbose logging via env var (no --debug flag)
             env = os.environ.copy()
             env["TRANSFORMERS_VERBOSITY"] = "debug"
-            # Build command with debug
-            # Build conversion command
-            # Rely on TRANSFORMERS_VERBOSITY for logging; remove unsupported debug flag
+
+            # 4. Build and run the conversion command
             cmd = [
                 sys.executable,
                 "-m", "scripts.convert",
@@ -107,7 +122,6 @@ class ModelConverter:
                 "--trust_remote_code",
                 "--model_id", input_model_id,
                 "--output_attentions",
-                "--debug"
             ]
             result = subprocess.run(
                 cmd,
@@ -116,28 +130,39 @@
                 text=True,
                 env=env,
             )
-            # Filter warnings
-            filtered = [ln for ln in result.stderr.splitlines() if not ln.startswith("Moving the following attributes") and "TracerWarning" not in ln]
+
+            # 5. Filter out spurious warnings from stderr
+            filtered = []
+            for ln in result.stderr.splitlines():
+                if ln.startswith("Moving the following attributes"):
+                    continue
+                if "TracerWarning" in ln:
+                    continue
+                filtered.append(ln)
             stderr = "\n".join(filtered)
+
             if result.returncode != 0:
                 return False, stderr
             return True, stderr
+
         except Exception as e:
             return False, str(e)
 
     def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
+        """Upload the converted model to Hugging Face Hub."""
         model_folder = self.config.repo_path / "models" / input_model_id
         try:
             self.api.create_repo(output_model_id, exist_ok=True, private=False)
-            readme = model_folder / "README.md"
-            if not readme.exists():
-                readme.write_text(self.generate_readme(input_model_id))
+            readme_path = model_folder / "README.md"
+            if not readme_path.exists():
+                readme_path.write_text(self.generate_readme(input_model_id))
             self.api.upload_folder(folder_path=str(model_folder), repo_id=output_model_id)
             return None
         except Exception as e:
             return str(e)
         finally:
-            import shutil; shutil.rmtree(model_folder, ignore_errors=True)
+            import shutil
+            shutil.rmtree(model_folder, ignore_errors=True)
 
     def generate_readme(self, imi: str) -> str:
         return (
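
`upload_model` leans entirely on `huggingface_hub`: create the destination repo idempotently, push the folder, and always remove the local copy. A trimmed sketch of that call sequence with placeholder token and IDs:

```python
import shutil
from pathlib import Path

from huggingface_hub import HfApi

api = HfApi(token="hf_...")                # placeholder token
repo_id = "your-username/pythia-14m-ONNX"  # placeholder destination
model_folder = Path("models/EleutherAI/pythia-14m")

try:
    api.create_repo(repo_id, exist_ok=True, private=False)  # idempotent
    api.upload_folder(folder_path=str(model_folder), repo_id=repo_id)
finally:
    # Clean up the local copy whether or not the upload succeeded.
    shutil.rmtree(model_folder, ignore_errors=True)
```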
@@ -148,31 +173,68 @@ class ModelConverter:
             "---\n\n"
             f"# {imi.split('/')[-1]} (ONNX)\n\n"
             f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
-            "Converted with debug logs and attention maps.\n"
+            "Converted with attention maps and verbose export logs.\n"
         )
 
+
 def main():
-    st.write("## Convert a Hugging Face model to ONNX (with debug)")
+    """Streamlit application entry point."""
+    st.write("## Convert a Hugging Face model to ONNX (with attentions & debug logs)")
+
     try:
         config = Config.from_env()
-        conv = ModelConverter(config)
-        conv.setup_repository()
-        input_id = st.text_input("Model ID e.g. EleutherAI/pythia-14m")
-        if not input_id: return
-        st.text_input("HF write token (optional)", type="password", key="user_hf_token")
-        same = st.checkbox("Upload to same repo?", value=False) if config.hf_username == input_id.split("/")[0] else False
-        name = input_id.split("/")[-1]; out = f"{config.hf_username}/{name}" + ("" if same else "-ONNX")
-        url = f"{config.hf_base_url}/{out}"; st.code(url)
-        if not st.button("Proceed"): return
-        with st.spinner("Converting (debug)..."):
-            ok, err = conv.convert_model(input_id)
-        if not ok: st.error(f"Conversion failed: {err}"); return
-        st.success("Conversion successful!"); st.code(err)
-        with st.spinner("Uploading..."):
-            err2 = conv.upload_model(input_id, out)
-        if err2: st.error(f"Upload failed: {err2}"); return
-        st.success("Upload successful!"); st.link_button(f"Go to {out}", url)
+        converter = ModelConverter(config)
+        converter.setup_repository()
+
+        input_model_id = st.text_input(
+            "Enter the Hugging Face model ID to convert, e.g. `EleutherAI/pythia-14m`"
+        )
+        if not input_model_id:
+            return
+
+        st.text_input(
+            "Optional: Your Hugging Face write token (for uploading to your namespace).",
+            type="password",
+            key="user_hf_token",
+        )
+
+        if config.hf_username == input_model_id.split("/")[0]:
+            same_repo = st.checkbox("Upload ONNX weights to the same repository?")
+        else:
+            same_repo = False
+
+        model_name = input_model_id.split("/")[-1]
+        output_model_id = f"{config.hf_username}/{model_name}"
+        if not same_repo:
+            output_model_id += "-ONNX"
+
+        output_url = f"{config.hf_base_url}/{output_model_id}"
+        st.write("Destination repository:")
+        st.code(output_url, language="plaintext")
+
+        if not st.button("Proceed", type="primary"):
+            return
+
+        with st.spinner("Converting model…"):
+            success, stderr = converter.convert_model(input_model_id)
+            if not success:
+                st.error(f"Conversion failed: {stderr}")
+                return
+            st.success("Conversion successful!")
+            st.code(stderr)
+
+        with st.spinner("Uploading model…"):
+            error = converter.upload_model(input_model_id, output_model_id)
+            if error:
+                st.error(f"Upload failed: {error}")
+                return
+            st.success("Upload successful!")
+            st.link_button(f"Go to {output_model_id}", output_url, type="primary")
+
     except Exception as e:
-        logger.exception(e); st.error(f"Error: {e}")
+        logger.exception("Application error")
+        st.error(f"An error occurred: {e}")
+
 
-if __name__ == "__main__": main()
+if __name__ == "__main__":
+    main()
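
The destination-repo naming in the rewritten `main()` is plain string manipulation; a worked example for a hypothetical user "alice" converting the app's placeholder model:

```python
hf_username = "alice"                     # hypothetical username
input_model_id = "EleutherAI/pythia-14m"  # the app's example placeholder
same_repo = False                         # checkbox only offered for own repos

model_name = input_model_id.split("/")[-1]
output_model_id = f"{hf_username}/{model_name}"
if not same_repo:
    output_model_id += "-ONNX"

assert output_model_id == "alice/pythia-14m-ONNX"
```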