Spaces:
Running
Running
File size: 4,990 Bytes
08e5ef1 d5eb6e1 383d050 5fd1a0a 7edda8b 2c3ad17 75b770e 08e5ef1 aa85862 ac97e5b 08e5ef1 1fba392 925d15e 08e5ef1 2bede7c c613bb1 ac97e5b 925d15e 7686e09 5b4e988 aa85862 8d4ed6d aa85862 ac97e5b c613bb1 ac97e5b ae9159b ac97e5b 2c3ad17 7c36326 aa85862 5696fee af46da6 eefa44d af46da6 eefa44d 9781999 eefa44d 21f9f87 d5eb6e1 9781999 d5eb6e1 9781999 5b4e988 9781999 00dc59f 2bede7c 00dc59f 098f871 ec000c3 3ad22ce 4c4c78d 3ad22ce 4c4c78d 098f871 3ad22ce 098f871 4c4c78d 892a74e 3ad22ce 4c4c78d 3ad22ce 098f871 c360795 3ad22ce 2bede7c 925d15e 098f871 925d15e b31944c 925d15e 2bede7c c360795 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import tempfile
os.environ["HF_HUB_CACHE"] = "cache"
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
import gradio as gr
from huggingface_hub import HfApi
from huggingface_hub import whoami
from huggingface_hub import ModelCard
from huggingface_hub import scan_cache_dir
from huggingface_hub import logging
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from apscheduler.schedulers.background import BackgroundScheduler
from textwrap import dedent
import mlx_lm
from mlx_lm import convert
# Space-level access token (None when the variable is unset); used by restart_space.
HF_TOKEN = os.getenv("HF_TOKEN")
def clear_hf_cache_space():
    """Delete every cached *model* revision from the local Hugging Face hub cache.

    Non-model entries (datasets, spaces) are left in place. If the cache
    directory does not exist yet, there is nothing to clear and the function
    returns quietly. Callers such as ``process_model`` run this from a
    ``finally`` clause, where a raised exception would replace the result
    being returned to the UI.
    """
    # Local import keeps the file-level import block unchanged.
    from huggingface_hub.utils import CacheNotFound

    try:
        scan = scan_cache_dir()
    except CacheNotFound:
        # Nothing was downloaded yet (e.g. the conversion failed early).
        print("No cache to clear")
        return

    to_delete = [
        rev.commit_hash
        for repo in scan.repos
        if repo.repo_type == "model"
        for rev in repo.revisions
    ]
    scan.delete_revisions(*to_delete).execute()
    print("Cache has been cleared")
def upload_to_hub(path, upload_repo, hf_path, token):
    """Write an MLX model card into ``path`` and push the folder to the Hub.

    The card starts from the source model's card (``hf_path``). It is tagged
    ``"mlx"``, its ``base_model`` is set to the source repo, and its body is
    replaced with usage instructions for the converted repo.

    Args:
        path: Local directory holding the converted model; ``README.md`` is
            written here before the upload.
        upload_repo: Destination repo id (``"<user>/<name>"``); created if it
            does not already exist.
        hf_path: Source repo id the conversion started from.
        token: User token used for every Hub call made here.
    """
    # Use the user's token so cards of gated/private source repos can be read,
    # consistent with the token used for the upload below.
    card = ModelCard.load(hf_path, token=token)
    tags = list(card.data.tags or [])
    # Don't add a second "mlx" tag if the source card already carries one.
    if "mlx" not in tags:
        tags.append("mlx")
    card.data.tags = tags
    card.data.base_model = hf_path
    card.text = dedent(
        f"""
        # {upload_repo}

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{mlx_lm.__version__}**.

        ## Use with mlx

        ```bash
        pip install mlx-lm
        ```

        ```python
        from mlx_lm import load, generate

        model, tokenizer = load("{upload_repo}")

        prompt="hello"

        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            messages = [{{"role": "user", "content": prompt}}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        ```
        """
    )
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi(token=token)
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    # NOTE(review): multi_commits is an experimental upload_folder option that
    # later huggingface_hub releases deprecated — confirm against the pinned version.
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
    """Convert ``model_id`` to a quantized MLX model and upload it under the user's namespace.

    Args:
        model_id: Source Hub repo id (``"<org>/<name>"``) picked in the UI.
        q_method: Quantization choice from the dropdown, ``"Q4"`` or ``"Q8"``;
            unrecognised values fall back to 4-bit.
        oauth_token: Token of the logged-in user, or ``None`` when logged out.

    Returns:
        A ``(markdown, image_path)`` pair for the two output components:
        a link to the new repo with ``"llama.png"`` on success, or
        ``"Error: ..."`` with ``"error.png"`` if conversion/upload fails.

    Raises:
        ValueError: If no user is logged in.
    """
    # The parameter is Optional: guard against None before touching .token.
    if oauth_token is None or oauth_token.token is None:
        raise ValueError("You must be logged in to use MLX-my-repo")

    model_name = model_id.split('/')[-1]
    print(model_name)
    username = whoami(oauth_token.token)["name"]
    print(username)

    # Honour the dropdown; 4 bits matches mlx_lm.convert's default q_bits.
    q_bits = {"Q4": 4, "Q8": 8}.get(q_method, 4)

    try:
        upload_repo = username + "/" + model_name + "-mlx"
        print(upload_repo)
        # TemporaryDirectory(dir=...) fails if the parent directory is missing.
        os.makedirs("converted", exist_ok=True)
        with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
            # The target dir must not exist
            mlx_path = os.path.join(tmpdir, "mlx")
            convert(model_id, mlx_path=mlx_path, quantize=True, q_bits=q_bits)
            print("Conversion done")
            upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, token=oauth_token.token)
            print("Upload done")
        return (
            f'Find your repo <a href="https://hf.co/{upload_repo}" target="_blank" style="text-decoration:underline">here</a>',
            "llama.png",
        )
    except Exception as e:
        # Top-level UI boundary: surface the failure in the output panel.
        return (f"Error: {e}", "error.png")
    finally:
        # Free disk space used by downloaded source weights after every run.
        clear_hf_cache_space()
        print("Folder cleaned up successfully!")
# Page CSS: let the container scroll when the output grows taller than the viewport.
css="""/* Custom CSS to allow scrolling */
.gradio-container {overflow-y: auto;}
"""

# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("You must be logged in to use MLX-my-repo.")
    # Login control; process_model takes the resulting gr.OAuthToken argument.
    gr.LoginButton(min_width=250)

    model_id = HuggingfaceHubSearch(
        label="Hub Model ID",
        placeholder="Search for model id on Huggingface",
        search_type="model",
    )

    q_method = gr.Dropdown(
        ["Q4", "Q8"],
        label="Quantization Method",
        info="MLX quantization type",
        value="Q4",
        filterable=False,
        visible=True,
    )

    iface = gr.Interface(
        fn=process_model,
        inputs=[
            model_id,
            q_method,
        ],
        outputs=[
            gr.Markdown(label="output"),
            gr.Image(show_label=False),
        ],
        title="Create your own MLX Quants, blazingly fast ⚡!",
        # The repo is created without `private=`, i.e. public — say so accurately.
        description="The space takes an HF repo as an input, quantizes it and creates a public repo containing the selected quant under your HF user namespace.",
        api_name=False,
    )
def restart_space():
    """Factory-reboot the running Space through the Hub API.

    Targets the Space named by the ``SPACE_ID`` environment variable so that
    a duplicated Space restarts itself rather than the upstream one. Falls
    back to the original hard-coded id when the variable is absent (e.g. in
    local runs). Authenticates with the module-level ``HF_TOKEN``.
    """
    space_id = os.environ.get("SPACE_ID", "reach-vb/mlx-my-repo")
    HfApi().restart_space(repo_id=space_id, token=HF_TOKEN, factory_reboot=True)
# Reboot the Space every 6 hours (21600 s) in a background thread.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", hours=6)
scheduler.start()

# Launch the interface: one job at a time, at most 5 waiting in the queue.
demo.queue(default_concurrency_limit=1, max_size=5).launch(debug=True, show_api=False)