reach-vb HF staff pcuenq HF staff commited on
Commit
173d502
1 Parent(s): af46da6
Files changed (3) hide show
  1. app.py +22 -28
  2. cache/.keep +0 -0
  3. converted/.keep +0 -0
app.py CHANGED
@@ -1,20 +1,17 @@
1
  import os
2
- import shutil
3
- import subprocess
4
- import signal
5
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
6
  import gradio as gr
7
 
8
- from huggingface_hub import create_repo, HfApi
9
- from huggingface_hub import snapshot_download
10
  from huggingface_hub import whoami
11
  from huggingface_hub import ModelCard
12
- from huggingface_hub import login
13
  from huggingface_hub import scan_cache_dir
14
  from huggingface_hub import logging
15
 
16
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
17
-
18
  from apscheduler.schedulers.background import BackgroundScheduler
19
 
20
  from textwrap import dedent
@@ -22,23 +19,24 @@ from textwrap import dedent
22
  import mlx_lm
23
  from mlx_lm import convert
24
 
25
- from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
26
-
27
  HF_TOKEN = os.environ.get("HF_TOKEN")
28
 
 
 
 
 
 
 
29
  def clear_hf_cache_space():
30
  scan = scan_cache_dir()
31
  to_delete = []
32
  for repo in scan.repos:
33
  if repo.repo_type == "model":
34
- to_delete.append([rev.commit_hash for rev in repo.revisions])
35
-
36
- scan.delete_revisions(to_delete)
37
-
38
  print("Cache has been cleared")
39
 
40
  def upload_to_hub(path, upload_repo, hf_path, token):
41
-
42
  card = ModelCard.load(hf_path)
43
  card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
44
  card.data.base_model = hf_path
@@ -86,33 +84,29 @@ def upload_to_hub(path, upload_repo, hf_path, token):
86
  )
87
  print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
88
 
89
- def process_model(model_id, q_method,oauth_token: gr.OAuthToken | None):
90
-
91
  if oauth_token.token is None:
92
  raise ValueError("You must be logged in to use MLX-my-repo")
93
 
94
  model_name = model_id.split('/')[-1]
95
- print(model_name)
96
  username = whoami(oauth_token.token)["name"]
97
- print(username)
98
-
99
- # login(token=oauth_token.token, add_to_git_credential=True)
100
-
101
  try:
102
- upload_repo = username + "/" + model_name + "-mlx"
103
  print(upload_repo)
104
- convert(model_id, quantize=True)
105
- print("Conversion done")
106
- upload_to_hub(path="mlx_model", upload_repo=upload_repo, hf_path=repo_id, token=oauth_token.token)
107
- print("Upload done")
 
 
 
108
  return (
109
- f'Find your repo <a href=\'{new_repo_url}\' target="_blank" style="text-decoration:underline">here</a>',
110
  "llama.png",
111
  )
112
  except Exception as e:
113
  return (f"Error: {e}", "error.png")
114
  finally:
115
- shutil.rmtree("mlx_model", ignore_errors=True)
116
  clear_hf_cache_space()
117
  print("Folder cleaned up successfully!")
118
 
 
1
  import os
2
+ import tempfile
3
+
4
+ os.environ["HF_HUB_CACHE"] = "cache"
5
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
6
  import gradio as gr
7
 
8
+ from huggingface_hub import HfApi
 
9
  from huggingface_hub import whoami
10
  from huggingface_hub import ModelCard
 
11
  from huggingface_hub import scan_cache_dir
12
  from huggingface_hub import logging
13
 
14
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
 
15
  from apscheduler.schedulers.background import BackgroundScheduler
16
 
17
  from textwrap import dedent
 
19
  import mlx_lm
20
  from mlx_lm import convert
21
 
 
 
22
  HF_TOKEN = os.environ.get("HF_TOKEN")
23
 
24
+ # I'm not sure if we need to add more stuff here
25
+ QUANT_PARAMS = {
26
+ "Q4": 4,
27
+ "Q8": 8,
28
+ }
29
+
30
  def clear_hf_cache_space():
31
  scan = scan_cache_dir()
32
  to_delete = []
33
  for repo in scan.repos:
34
  if repo.repo_type == "model":
35
+ to_delete.extend([rev.commit_hash for rev in repo.revisions])
36
+ scan.delete_revisions(*to_delete).execute()
 
 
37
  print("Cache has been cleared")
38
 
39
  def upload_to_hub(path, upload_repo, hf_path, token):
 
40
  card = ModelCard.load(hf_path)
41
  card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
42
  card.data.base_model = hf_path
 
84
  )
85
  print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
86
 
87
+ def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
 
88
  if oauth_token.token is None:
89
  raise ValueError("You must be logged in to use MLX-my-repo")
90
 
91
  model_name = model_id.split('/')[-1]
 
92
  username = whoami(oauth_token.token)["name"]
 
 
 
 
93
  try:
94
+ upload_repo = f"{username}/{model_name}-{q_method}-mlx"
95
  print(upload_repo)
96
+ with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
97
+ # The target dir must not exist
98
+ mlx_path = os.path.join(tmpdir, "mlx")
99
+ convert(model_id, mlx_path=mlx_path, quantize=True, q_bits=QUANT_PARAMS[q_method])
100
+ print("Conversion done")
101
+ upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, token=oauth_token.token)
102
+ print("Upload done")
103
  return (
104
+ f'Find your repo <a href="https://hf.co/{upload_repo}" target="_blank" style="text-decoration:underline">here</a>',
105
  "llama.png",
106
  )
107
  except Exception as e:
108
  return (f"Error: {e}", "error.png")
109
  finally:
 
110
  clear_hf_cache_space()
111
  print("Folder cleaned up successfully!")
112
 
cache/.keep ADDED
File without changes
converted/.keep ADDED
File without changes