Spaces:
Sleeping
Sleeping
File size: 4,000 Bytes
ce7bf5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import json
import os
import tempfile
import requests
import chroma
# Parent of the directory that holds ``chroma.__file__`` (two ``dirname`` steps up).
# Used below as the default ``key_directory`` in which ``config.json`` is stored.
ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__))
def register_key(key: str, key_directory=ROOT_DIR) -> None:
    """
    Registers the provided key by saving it to a JSON file.

    The token is written to ``config.json`` inside ``key_directory`` as
    ``{"access_token": key}``, overwriting any previously registered token.

    Args:
        key (str): The access token to be registered.
        key_directory (str, optional): The directory where the access key is registered.
            It is created if it does not exist yet.

    Returns:
        None
    """
    # An empty string means "current directory"; os.makedirs("") would raise,
    # so only create the directory when a non-empty path was given.
    if key_directory:
        os.makedirs(key_directory, exist_ok=True)
    config_path = os.path.join(key_directory, "config.json")

    # The token is a credential: create the file with owner-only permissions.
    # The mode only takes effect when the file is newly created; an existing
    # file keeps whatever permissions it already has.
    fd = os.open(config_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump({"access_token": key}, f)
def read_key(key_directory=ROOT_DIR) -> str:
    """
    Returns the access token stored in ``config.json`` under ``key_directory``.

    When the file is absent, a short hint on how to obtain a token is printed
    and a FileNotFoundError is raised.

    Args:
        key_directory (str, optional): The directory where the access key is registered.

    Returns:
        str: The registered access token.

    Raises:
        FileNotFoundError: If no key has been registered.
    """
    token_file = os.path.join(key_directory, "config.json")

    # Success path first: the file is there, so load it and hand back the token.
    if os.path.exists(token_file):
        with open(token_file, "r") as handle:
            stored = json.load(handle)
        return stored["access_token"]

    # Nothing registered yet: tell the user where to get a token, then fail.
    print("No access token has been registered.")
    print(
        "To obtain an access token, go to https://chroma-weights.generatebiomedicines.com/ and agree to the license."
    )
    raise FileNotFoundError("No token has been registered.")
def download_from_generate(
    base_url: str,
    weights_name: str,
    force: bool = False,
    exist_ok: bool = False,
    key_directory=ROOT_DIR,
) -> str:
    """
    Downloads data from the provided URL using the registered access token.
    Provides caching behavior based on force and exist_ok flags.

    The file is stored as ``weights.pt`` in a directory under the system temp
    directory, named after the MD5 hash of ``base_url + weights_name``. The
    response body is streamed to a partial file and moved into place only after
    the download completes, so an interrupted download never leaves a truncated
    file at the cache location.

    Args:
        base_url (str): The base URL from which data should be fetched.
        weights_name (str): Name of the weights to request; sent as the
            ``weights`` query parameter and used in the cache key.
        force (bool): If True, always fetches data from the URL regardless of cache existence.
        exist_ok (bool): If True and cache exists (and force is False), uses the cached data.
        key_directory (str, optional): The directory where the access key is registered.

    Returns:
        str: Path to the downloaded (or cached) file.

    Raises:
        FileNotFoundError: If no access token has been registered.
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    # Create a hash of the URL + weight name to determine the path for the cached/temporary file
    url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest()
    temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash)
    destination = os.path.join(temp_dir, "weights.pt")

    # Ensure the directory exists
    os.makedirs(temp_dir, exist_ok=True)

    # Use the cache only when explicitly allowed and not overridden by force
    if exist_ok and not force and os.path.exists(destination):
        print(f"Using cached data from {destination}")
        return destination

    # If not using cache, proceed with download
    # Define the query parameters (the token travels as a query parameter)
    params = {"token": read_key(key_directory), "weights": weights_name}

    # Download next to the destination so the final os.replace stays on one
    # filesystem; the pid keeps concurrent processes from sharing a partial file.
    partial_path = f"{destination}.{os.getpid()}.part"
    try:
        # stream=True avoids holding the whole weights file in memory
        with requests.get(base_url, params=params, stream=True) as response:
            response.raise_for_status()  # Raise an error for HTTP errors
            with open(partial_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
        # Publish the finished file in one step; any existing cache is replaced
        os.replace(partial_path, destination)
    except BaseException:
        # Clean up the partial file on any failure (including Ctrl-C), then re-raise
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise

    print(f"Data saved to {destination}")
    return destination
|