Spaces:
Running
Running
import os | |
import re | |
import sys | |
import shutil | |
import zipfile | |
import requests | |
from bs4 import BeautifulSoup | |
from urllib.parse import unquote | |
from tqdm import tqdm | |
# Remember the launch directory: it is the base for the "logs" paths below
# and is restored by download_from_url after each download. Appending it to
# sys.path makes the `rvc` package imports that follow resolvable.
now_dir = os.getcwd()
sys.path.append(now_dir)
from rvc.lib.utils import format_title | |
from rvc.lib.tools import gdown | |
# Extracted models live under <now_dir>/logs; raw downloads are staged in
# <now_dir>/logs/zips, which is created eagerly at import time.
file_path = os.path.join(now_dir, "logs")
zips_path = os.path.join(now_dir, "logs", "zips")
os.makedirs(zips_path, exist_ok=True)
def search_pth_index(folder):
    """Collect the ``.pth`` and ``.index`` files sitting directly in *folder*.

    Only regular files at the top level are considered; subdirectories are
    neither returned nor descended into.

    Args:
        folder: Directory to scan.

    Returns:
        A ``(pth_paths, index_paths)`` tuple of lists with the joined path of
        every ``.pth`` file and every ``.index`` file, in ``os.listdir`` order.
    """
    pth_paths, index_paths = [], []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if not os.path.isfile(full_path):
            continue
        if entry.endswith(".pth"):
            pth_paths.append(full_path)
        elif entry.endswith(".index"):
            index_paths.append(full_path)
    return pth_paths, index_paths
def download_from_url(url):
    """Download *url* into ``zips_path``, choosing a strategy from the URL shape.

    The working directory is switched to ``zips_path`` for the duration of the
    call and always restored to ``now_dir`` afterwards. Presumably this is so
    that ``gdown.download`` (called without an output path) saves into
    ``zips_path``.

    Dispatch:
      * ``drive.google.com`` links -> ``gdown`` using the extracted file id.
      * URLs containing ``/blob/`` or ``/resolve/`` -> ``download_blob_or_resolve``.
      * URLs containing ``/tree/main`` -> ``download_from_huggingface``.
      * anything else -> ``download_file``.

    Returns:
        ``"downloaded"`` on success, or ``None`` if any step raised (the error
        is printed, not re-raised).
    """
    os.chdir(zips_path)
    try:
        if "drive.google.com" in url:
            file_id = extract_google_drive_id(url)
            if not file_id:
                # Fail loudly instead of reporting success for a download that
                # never happened; a stale zip left in zips_path could
                # otherwise be extracted as if it were this model.
                raise ValueError(f"Could not find a Google Drive file id in URL: {url}")
            gdown.download(
                url=f"https://drive.google.com/uc?id={file_id}",
                quiet=False,
                fuzzy=True,
            )
        elif "/blob/" in url or "/resolve/" in url:
            download_blob_or_resolve(url)
        elif "/tree/main" in url:
            download_from_huggingface(url)
        else:
            download_file(url)
        rename_downloaded_files()
        return "downloaded"
    except Exception as error:
        print(f"An error occurred downloading the file: {error}")
        return None
    finally:
        os.chdir(now_dir)
def extract_google_drive_id(url):
    """Return the Google Drive file id embedded in *url*, or ``None``.

    Two URL shapes are recognized:
      * ``.../file/d/<id>/...`` (the id ends at ``/``, ``?``, ``#`` or ``&``);
      * a query parameter named exactly ``id`` (``?id=<id>`` / ``&id=<id>``).

    Args:
        url: A Google Drive sharing or download URL.

    Returns:
        The file id string, or ``None`` when neither shape is present.
    """
    path_match = re.search(r"file/d/([^/?#&]+)", url)
    if path_match:
        return path_match.group(1)
    # The look-behind rejects parameters that merely end in "id", such as
    # "ouid=", which a plain substring search would have picked up.
    query_match = re.search(r"(?<!\w)id=([^&#]+)", url)
    if query_match:
        return query_match.group(1)
    return None
def download_blob_or_resolve(url):
    """Download a single file link (``/blob/`` or ``/resolve/``) into ``zips_path``.

    ``/blob/`` segments are rewritten to ``/resolve/`` before fetching, which
    on Hugging Face turns a viewer link into a raw-file link.

    Args:
        url: The file URL.

    Raises:
        ValueError: if the server answers with a status other than 200.
    """
    if "/blob/" in url:
        url = url.replace("/blob/", "/resolve/")
    # A streamed response keeps its pooled connection until the body is fully
    # read or the response is closed; the context manager guarantees the
    # close on the error path and if writing the file raises.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            raise ValueError(
                "Download failed with status code: " + str(response.status_code)
            )
        save_response_content(response)
def _filename_from_content_disposition(content_disposition):
    """Return the filename named in a decoded Content-Disposition value, or None."""
    patterns = (
        r'filename="([^"]+)"',  # filename="model.zip"
        r"filename\*=(?:[\w.-]+'[^']*')?([^;]+)",  # filename*=UTF-8''model.zip
        r"filename=([^;]+)",  # filename=model.zip (unquoted)
    )
    for pattern in patterns:
        match = re.search(pattern, content_disposition, re.IGNORECASE)
        if match:
            name = match.group(1).strip().strip('"')
            if name:
                return name
    return None


def save_response_content(response):
    """Stream *response*'s body into ``zips_path`` while showing a tqdm bar.

    The file name is taken from the ``Content-Disposition`` header when it
    carries a usable ``filename``/``filename*`` parameter, otherwise
    ``"downloaded_file"`` is used. Path separators in the server-supplied name
    are replaced with ``_`` so the name cannot introduce subdirectories.

    Args:
        response: A streamed ``requests`` response (``stream=True``).
    """
    # Percent-decode first so RFC 5987 ``filename*=UTF-8''...`` values and
    # encoded quoted names come out readable.
    content_disposition = unquote(response.headers.get("Content-Disposition", ""))
    file_name = (
        _filename_from_content_disposition(content_disposition) or "downloaded_file"
    )
    # The name is server-controlled: neutralize both separators (altsep is
    # "/" on Windows, None elsewhere).
    for separator in (os.path.sep, os.path.altsep):
        if separator:
            file_name = file_name.replace(separator, "_")

    # A missing Content-Length yields total=0 (unknown size) for tqdm.
    total_size = int(response.headers.get("Content-Length", 0))
    chunk_size = 1024
    with open(os.path.join(zips_path, file_name), "wb") as file, tqdm(
        total=total_size, unit="B", unit_scale=True, desc=file_name
    ) as progress_bar:
        for data in response.iter_content(chunk_size):
            file.write(data)
            progress_bar.update(len(data))
def download_from_huggingface(url):
    """Download the first ``.zip`` linked from a Hugging Face ``/tree/main`` page.

    Steps:
      1. Fetch and parse the listing page.
      2. Take the first anchor whose href ends in ``.zip``.
      3. Rewrite its ``/blob/`` (viewer) segment to ``/resolve/`` (raw file).
      4. Prefix ``https://huggingface.co`` when the href lacks the host.
      5. Fetch the file with ``download_file``.

    Args:
        url: The repository tree URL.

    Raises:
        ValueError: if the listing request fails or no ``.zip`` link is found.
    """
    response = requests.get(url)
    if response.status_code != 200:
        raise ValueError(
            "Download failed with status code: " + str(response.status_code)
        )
    soup = BeautifulSoup(response.content, "html.parser")
    zip_href = next(
        (
            link["href"]
            for link in soup.find_all("a", href=True)
            if link["href"].endswith(".zip")
        ),
        None,
    )
    if not zip_href:
        raise ValueError("No zip file found in Huggingface URL")
    # Rewrite only the path segment: a bare "blob" -> "resolve" replace also
    # corrupts user, repo or file names that merely contain "blob".
    zip_url = zip_href.replace("/blob/", "/resolve/")
    if "huggingface.co" not in zip_url:
        zip_url = "https://huggingface.co" + zip_url
    download_file(zip_url)
def download_file(url):
    """Stream *url* into ``zips_path`` via ``save_response_content``.

    Args:
        url: Direct URL of the file to download.

    Raises:
        ValueError: if the server answers with a status other than 200.
    """
    # With stream=True the connection is only released once the body is
    # consumed or the response is closed; ``with`` covers the non-200 path
    # and any failure while writing.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            raise ValueError(
                "Download failed with status code: " + str(response.status_code)
            )
        save_response_content(response)
def rename_downloaded_files():
    """Move every file found under ``zips_path`` to the top level of ``zips_path``.

    ``zips_path`` is walked recursively. Each file is moved to
    ``zips_path/<name>``, with any path separator in its stem replaced by
    ``_``. Files already at the top level are renamed onto themselves, i.e.
    left in place.
    """
    for current_dir, _, file_names in os.walk(zips_path):
        for file_name in file_names:
            stem, extension = os.path.splitext(file_name)
            source = os.path.join(current_dir, file_name)
            # Build an absolute destination so the result does not depend on
            # the caller's current working directory.
            destination = os.path.join(
                zips_path, stem.replace(os.path.sep, "_") + extension
            )
            os.rename(source, destination)
def extract(zipfile_path, unzips_path):
    """Unpack *zipfile_path* into *unzips_path*, then delete the archive.

    Any failure (missing file, corrupt archive, I/O error) is printed and
    reported through the return value instead of being raised.

    Args:
        zipfile_path: Path of the ``.zip`` archive.
        unzips_path: Destination directory (created as needed by zipfile).

    Returns:
        ``True`` if extraction and deletion both succeeded, else ``False``.
    """
    try:
        with zipfile.ZipFile(zipfile_path, "r") as archive:
            archive.extractall(unzips_path)
        os.remove(zipfile_path)
    except Exception as error:
        print(f"An error occurred extracting the zip file: {error}")
        return False
    return True
def unzip_file(zip_path, zip_file_name):
    """Extract ``<zip_path>/<zip_file_name>.zip`` into ``<file_path>/<zip_file_name>``.

    The archive is deleted only after extraction succeeds; any exception from
    opening or extracting propagates to the caller.

    Args:
        zip_path: Directory that holds the archive.
        zip_file_name: Archive name without the ``.zip`` extension; also used
            as the output folder name.
    """
    target_dir = os.path.join(file_path, zip_file_name)
    archive_path = os.path.join(zip_path, zip_file_name + ".zip")
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(target_dir)
    os.remove(archive_path)
def model_download_pipeline(url: str):
    """Download a model archive from *url* and unpack it.

    Args:
        url: Any URL shape accepted by ``download_from_url``.

    Returns:
        The result of ``handle_extraction_process`` when the download
        succeeded. Otherwise the string ``"Error"``; unexpected exceptions
        are printed and also mapped to ``"Error"``.
    """
    try:
        if download_from_url(url) != "downloaded":
            return "Error"
        return handle_extraction_process()
    except Exception as error:
        print(f"An unexpected error occurred: {error}")
        return "Error"
def handle_extraction_process():
    """Extract every ``.zip`` in ``zips_path`` into ``logs/<model_name>``.

    For each archive:
      * the model name is the file name minus its trailing ``.zip``, passed
        through ``format_title``;
      * ``extract`` unpacks the archive and deletes it;
      * ``clean_extracted_files`` tidies the resulting folder.

    Returns:
        ``search_pth_index(...)`` for the last extracted folder, i.e. a
        ``(pth_paths, index_paths)`` tuple. Returns ``"Error"`` if an
        extraction fails or no archive is present.
    """
    extract_folder_path = ""
    for filename in os.listdir(zips_path):
        if not filename.endswith(".zip"):
            continue
        zipfile_path = os.path.join(zips_path, filename)
        # Strip only the trailing extension; split(".zip")[0] would cut a
        # name such as "voice.zip.v2.zip" at its first ".zip".
        model_name = format_title(filename[: -len(".zip")])
        # NOTE(review): relative to the current working directory, unlike
        # the absolute module-level ``file_path``.
        extract_folder_path = os.path.join("logs", os.path.normpath(model_name))
        if not extract(zipfile_path, extract_folder_path):
            print(f"Error downloading {model_name}")
            return "Error"
        # Tidy only after a successful extraction: on failure the target
        # folder may not exist and os.listdir inside the cleaner would raise.
        clean_extracted_files(extract_folder_path, model_name)
        print(f"Model {model_name} downloaded!")
    if not extract_folder_path:
        print("Zip file was not found.")
        return "Error"
    return search_pth_index(extract_folder_path)
def clean_extracted_files(extract_folder_path, model_name):
    """Normalize the layout of a freshly extracted model folder, in place.

    Steps:
      1. Remove a ``__MACOSX`` metadata folder if present.
      2. If exactly one subdirectory remains, move its contents up one level
         and remove it. This handles archives that wrap everything in a
         single folder.
      3. Rename a ``*.pth`` file to ``<model_name>.pth`` and a ``*.index``
         file to ``<model_name>.index``, unless a file with that target name
         already exists. Files are visited in sorted name order, so the
         choice is deterministic.

    Args:
        extract_folder_path: Folder the archive was extracted into.
        model_name: Base name for the renamed weight/index files.
    """
    macosx_path = os.path.join(extract_folder_path, "__MACOSX")
    if os.path.exists(macosx_path):
        shutil.rmtree(macosx_path)

    subfolders = [
        f
        for f in os.listdir(extract_folder_path)
        if os.path.isdir(os.path.join(extract_folder_path, f))
    ]
    if len(subfolders) == 1:
        subfolder_path = os.path.join(extract_folder_path, subfolders[0])
        for item in os.listdir(subfolder_path):
            shutil.move(
                os.path.join(subfolder_path, item),
                os.path.join(extract_folder_path, item),
            )
        os.rmdir(subfolder_path)

    for item in sorted(os.listdir(extract_folder_path)):
        source_path = os.path.join(extract_folder_path, item)
        if not os.path.isfile(source_path):
            continue
        # Match on the extension, not a substring: "notes.pth.txt" is not a
        # weight file.
        if item.endswith(".pth"):
            new_file_name = model_name + ".pth"
        elif item.endswith(".index"):
            new_file_name = model_name + ".index"
        else:
            continue
        destination_path = os.path.join(extract_folder_path, new_file_name)
        if not os.path.exists(destination_path):
            os.rename(source_path, destination_path)