Update app.py
Browse files
app.py
CHANGED
|
@@ -5,6 +5,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
|
| 5 |
import gradio as gr
|
| 6 |
import tempfile
|
| 7 |
import torch
|
|
|
|
| 8 |
|
| 9 |
from huggingface_hub import HfApi, ModelCard, whoami
|
| 10 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
|
@@ -74,17 +75,32 @@ def run_command(command):
|
|
| 74 |
|
| 75 |
###########
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
|
| 79 |
if oauth_token is None or oauth_token.token is None:
|
| 80 |
raise gr.Error("You must be logged in")
|
| 81 |
model_name = ft_model_id.split('/')[-1]
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
if not os.path.exists("outputs"):
|
| 84 |
os.makedirs("outputs")
|
| 85 |
|
| 86 |
try:
|
| 87 |
api = HfApi(token=oauth_token.token)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
|
| 90 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -102,7 +118,7 @@ def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo,
|
|
| 102 |
print("output_stdout", output_stdout)
|
| 103 |
print("output_stderr", output_stderr)
|
| 104 |
if returncode != 0:
|
| 105 |
-
raise Exception(f"Error converting to LoRA PEFT {
|
| 106 |
print("Model converted to LoRA PEFT successfully!")
|
| 107 |
print(f"Converted model path: {outputdir}")
|
| 108 |
|
|
@@ -146,8 +162,8 @@ with gr.Blocks(css=css) as demo:
|
|
| 146 |
)
|
| 147 |
|
| 148 |
base_model_id = HuggingfaceHubSearch(
|
| 149 |
-
label="Base model repository",
|
| 150 |
-
placeholder="
|
| 151 |
search_type="model",
|
| 152 |
)
|
| 153 |
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
import tempfile
|
| 7 |
import torch
|
| 8 |
+
import requests
|
| 9 |
|
| 10 |
from huggingface_hub import HfApi, ModelCard, whoami
|
| 11 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
|
|
|
| 75 |
|
| 76 |
###########
|
| 77 |
|
| 78 |
+
def guess_base_model(ft_model_id):
|
| 79 |
+
res = requests.get(f"https://huggingface.co/api/models/{ft_model_id}")
|
| 80 |
+
res = res.json()
|
| 81 |
+
for tag in res["tags"]:
|
| 82 |
+
if tag.startswith("base_model:"):
|
| 83 |
+
return tag.split(":")[-1]
|
| 84 |
+
raise Exception("Cannot guess the base model, please enter it manually")
|
| 85 |
+
|
| 86 |
|
| 87 |
def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
|
| 88 |
if oauth_token is None or oauth_token.token is None:
|
| 89 |
raise gr.Error("You must be logged in")
|
| 90 |
model_name = ft_model_id.split('/')[-1]
|
| 91 |
|
| 92 |
+
# validate the oauth token
|
| 93 |
+
whoami(oauth_token.token)
|
| 94 |
+
|
| 95 |
if not os.path.exists("outputs"):
|
| 96 |
os.makedirs("outputs")
|
| 97 |
|
| 98 |
try:
|
| 99 |
api = HfApi(token=oauth_token.token)
|
| 100 |
+
|
| 101 |
+
if not base_model_id:
|
| 102 |
+
base_model_id = guess_base_model(ft_model_id)
|
| 103 |
+
print("guess_base_model", base_model_id)
|
| 104 |
|
| 105 |
with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
|
| 106 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 118 |
print("output_stdout", output_stdout)
|
| 119 |
print("output_stderr", output_stderr)
|
| 120 |
if returncode != 0:
|
| 121 |
+
raise Exception(f"Error converting to LoRA PEFT {output_stderr}")
|
| 122 |
print("Model converted to LoRA PEFT successfully!")
|
| 123 |
print(f"Converted model path: {outputdir}")
|
| 124 |
|
|
|
|
| 162 |
)
|
| 163 |
|
| 164 |
base_model_id = HuggingfaceHubSearch(
|
| 165 |
+
label="Base model repository (optional)",
|
| 166 |
+
placeholder="If empty, it will be guessed from repo tags",
|
| 167 |
search_type="model",
|
| 168 |
)
|
| 169 |
|