Spaces:
Sleeping
Sleeping
# Copyright Generate Biomedicines, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import hashlib | |
import json | |
import os | |
import tempfile | |
import requests | |
import chroma | |
ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__)) | |
def register_key(key: str, key_directory=ROOT_DIR) -> None: | |
""" | |
Registers the provided key by saving it to a JSON file. | |
Args: | |
key (str): The access token to be registered. | |
key_directory (str, optional): The directory where the access key is registered. | |
Returns: | |
None | |
""" | |
config_path = os.path.join(key_directory, "config.json") | |
with open(config_path, "w") as f: | |
json.dump({"access_token": key}, f) | |
def read_key(key_directory=ROOT_DIR) -> str: | |
""" | |
Reads the registered key from the JSON file. If no key has been registered, | |
it informs the user and raises a FileNotFoundError. | |
Args: | |
key_directory (str, optional): The directory where the access key is registered. | |
Returns: | |
str: The registered access token. | |
Raises: | |
FileNotFoundError: If no key has been registered. | |
""" | |
config_path = os.path.join(key_directory, "config.json") | |
if not os.path.exists(config_path): | |
print("No access token has been registered.") | |
print( | |
"To obtain an access token, go to https://chroma-weights.generatebiomedicines.com/ and agree to the license." | |
) | |
raise FileNotFoundError("No token has been registered.") | |
with open(config_path, "r") as f: | |
config = json.load(f) | |
return config["access_token"] | |
def download_from_generate( | |
base_url: str, | |
weights_name: str, | |
force: bool = False, | |
exist_ok: bool = False, | |
key_directory=ROOT_DIR, | |
) -> str: | |
""" | |
Downloads data from the provided URL using the registered access token. | |
Provides caching behavior based on force and exist_ok flags. | |
Args: | |
base_url (str): The base URL from which data should be fetched. | |
force (bool): If True, always fetches data from the URL regardless of cache existence. | |
exist_ok (bool): If True and cache exists (and force is False), uses the cached data. | |
key_directory (str, optional): The directory where the access key is registered. | |
Returns: | |
str: Path to the downloaded (or cached) file. | |
""" | |
# Create a hash of the URL + weight name to determine the path for the cached/temporary file | |
url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest() | |
temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash) | |
destination = os.path.join(temp_dir, "weights.pt") | |
# Ensure the directory exists | |
os.makedirs(temp_dir, exist_ok=True) | |
# Check if cache exists | |
cache_exists = os.path.exists(destination) | |
# Determine if we should use the cache or not | |
use_cache = cache_exists and exist_ok and not force | |
if use_cache: | |
print(f"Using cached data from {destination}") | |
return destination | |
# If not using cache, proceed with download | |
# Define the query parameters | |
params = {"token": read_key(key_directory), "weights": weights_name} | |
# Perform the GET request with the token as a query parameter | |
response = requests.get(base_url, params=params) | |
response.raise_for_status() # Raise an error for HTTP errors | |
with open(destination, "wb") as file: | |
file.write(response.content) | |
print(f"Data saved to {destination}") | |
return destination | |