File size: 4,000 Bytes
ce7bf5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import os
import tempfile

import requests

import chroma

# Directory two levels up from ``chroma.__file__`` (i.e. the directory that
# contains the ``chroma`` package, assuming ``__file__`` is the package's
# ``__init__.py``). Used below as the default location of ``config.json``,
# the file in which the access token is stored.
ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__))


def register_key(key: str, key_directory=ROOT_DIR) -> None:
    """
    Persist an access token so that later calls can read it back.

    The token is written as ``{"access_token": key}`` to a ``config.json``
    file inside ``key_directory``. The file is opened in write mode, so any
    existing ``config.json`` in that directory is overwritten.

    Args:
        key (str): The access token to store.
        key_directory (str, optional): Directory in which ``config.json`` is
            written. Defaults to ``ROOT_DIR``.

    Returns:
        None
    """
    record = {"access_token": key}
    target = os.path.join(key_directory, "config.json")
    with open(target, mode="w") as handle:
        json.dump(record, handle)


def read_key(key_directory=ROOT_DIR) -> str:
    """
    Load the access token previously stored in ``config.json``.

    When ``config.json`` is absent from ``key_directory``, a short message
    explaining how to obtain a token is printed before the error is raised.

    Args:
        key_directory (str, optional): Directory that holds ``config.json``.
            Defaults to ``ROOT_DIR``.

    Returns:
        str: The value stored under the ``"access_token"`` key.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in
            ``key_directory``.
    """
    token_file = os.path.join(key_directory, "config.json")

    # Success path first: parse the file and hand back the stored token.
    if os.path.exists(token_file):
        with open(token_file, "r") as handle:
            return json.load(handle)["access_token"]

    # No config file: tell the user how to get a token, then fail.
    print("No access token has been registered.")
    print(
        "To obtain an access token, go to https://chroma-weights.generatebiomedicines.com/ and agree to the license."
    )
    raise FileNotFoundError("No token has been registered.")


def download_from_generate(
    base_url: str,
    weights_name: str,
    force: bool = False,
    exist_ok: bool = False,
    key_directory=ROOT_DIR,
    timeout=(10.0, 300.0),
) -> str:
    """
    Downloads data from the provided URL using the registered access token.
    Provides caching behavior based on force and exist_ok flags.

    The file is cached at
    ``<system temp dir>/chroma_weights/<md5(base_url + weights_name)>/weights.pt``.
    The response body is streamed to a sibling ``.part`` file and moved onto
    the cache path only after the download completes, so an interrupted
    download never leaves a truncated ``weights.pt`` that a later
    ``exist_ok=True`` call would mistake for a valid cache entry.

    Args:
        base_url (str): The base URL from which data should be fetched.
        weights_name (str): Name of the weights to request; sent to the server
            as the ``weights`` query parameter and included in the cache key.
        force (bool): If True, always fetches data from the URL regardless of cache existence.
        exist_ok (bool): If True and cache exists (and force is False), uses the cached data.
        key_directory (str, optional): The directory where the access key is registered.
        timeout (float or tuple, optional): Passed through to ``requests.get``.
            Either a single number of seconds or a ``(connect, read)`` pair.
            ``None`` disables the timeout (the previous behavior).

    Returns:
        str: Path to the downloaded (or cached) file.

    Raises:
        FileNotFoundError: If no access token has been registered (raised by
            ``read_key``).
        requests.HTTPError: If the server answers with an HTTP error status.
    """

    # Create a hash of the URL + weight name to determine the path for the cached/temporary file
    # (md5 is used only to derive a directory name, not for security).
    # NOTE(review): the cache path is predictable and lives under the system
    # temp dir; confirm this is acceptable on multi-user machines.
    url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest()
    temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash)
    destination = os.path.join(temp_dir, "weights.pt")

    # Ensure the directory exists
    os.makedirs(temp_dir, exist_ok=True)

    # Use the cache only when the caller allows it and it is actually there
    if exist_ok and not force and os.path.exists(destination):
        print(f"Using cached data from {destination}")
        return destination

    # If not using cache, proceed with download

    # Define the query parameters
    params = {"token": read_key(key_directory), "weights": weights_name}

    # Per-process name so concurrent downloads of the same weights do not
    # write into each other's partial file.
    partial_path = f"{destination}.{os.getpid()}.part"

    # Stream the body instead of buffering the whole file in memory, and bound
    # how long we wait on an unresponsive server.
    with requests.get(
        base_url, params=params, stream=True, timeout=timeout
    ) as response:
        response.raise_for_status()  # Raise an error for HTTP errors

        try:
            with open(partial_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
            # Atomic move: readers see either the old file or the complete new one.
            os.replace(partial_path, destination)
        except BaseException:
            # Best-effort cleanup of the partial file; the original error is
            # what the caller needs to see, so it is always re-raised.
            try:
                os.remove(partial_path)
            except OSError:
                pass
            raise

    print(f"Data saved to {destination}")
    return destination