Spaces:
Runtime error
Runtime error
import pickle | |
import time | |
from grazie.api.client.chat.prompt import ChatPrompt | |
from grazie.api.client.endpoints import GrazieApiGatewayUrls | |
from grazie.api.client.gateway import AuthType, GrazieAgent, GrazieApiGatewayClient | |
from grazie.api.client.profiles import LLMProfile | |
import config | |
def _build_gateway_client():
    """Construct the Grazie API gateway client shared by all LLM requests.

    The client identifies itself as agent ``grazie-toolformers`` v1.0, talks
    to the staging gateway, and authenticates as an application using the JWT
    token read from ``config.GRAZIE_API_JWT_TOKEN``.
    """
    agent = GrazieAgent("grazie-toolformers", "v1.0")
    return GrazieApiGatewayClient(
        grazie_agent=agent,
        url=GrazieApiGatewayUrls.STAGING,
        auth_type=AuthType.APPLICATION,
        grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
    )


# Module-wide client instance, created once at import time.
client = _build_gateway_client()
# On-disk cache of LLM responses, one pickle file per configured model.
# LLM_CACHE maps prompt -> list of responses received for that prompt.
# LLM_CACHE_USED maps prompt -> how many of those cached responses have already
# been handed out during this process; it is never persisted.
LLM_CACHE_FILE = config.CACHE_DIR / f"{config.LLM_MODEL}.cache.pkl"
LLM_CACHE_USED = {}

if LLM_CACHE_FILE.exists():
    # NOTE: unpickling runs arbitrary code if the file has been tampered with.
    # This is acceptable only because the file is a local cache that this
    # module writes itself; never point it at untrusted data.
    with open(LLM_CACHE_FILE, "rb") as file:
        LLM_CACHE = pickle.load(file=file)
else:
    LLM_CACHE = {}
    # The cache directory may not exist yet on a fresh checkout or container;
    # without this the open() below raises FileNotFoundError at import time.
    LLM_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
    # Materialise an empty cache file up front (the original behaviour).
    with open(LLM_CACHE_FILE, "wb") as file:
        pickle.dump(obj=LLM_CACHE, file=file)
def llm_request(prompt, max_retries=None):
    """Send *prompt* to the configured chat model and return the reply text.

    The prompt is sent as the user message after a fixed system message
    ("You are a helpful assistant."), using ``config.LLM_MODEL`` as the
    profile. Failed attempts are retried after sleeping for
    ``config.GRAZIE_TIMEOUT_SEC`` seconds. An attempt fails when the client
    raises, or when it returns a response whose ``content`` is ``None``.

    Args:
        prompt: Text sent as the user message.
        max_retries: Number of retries allowed after the first failed attempt.
            ``None`` (the default) retries indefinitely, matching the previous
            behaviour; ``0`` means a single attempt.

    Returns:
        The ``content`` of the chat response; never ``None``.

    Raises:
        RuntimeError: Only when ``max_retries`` is set and all attempts
            failed. The last client exception, if any, is chained as the cause.
    """
    import logging

    logger = logging.getLogger(__name__)
    failures = 0
    last_error = None
    while True:
        try:
            output = client.chat(
                chat=ChatPrompt().add_system("You are a helpful assistant.").add_user(prompt),
                profile=LLMProfile(config.LLM_MODEL),
            ).content
        except Exception as err:  # the client's concrete error types are not visible here
            last_error = err
            logger.warning("LLM request failed (attempt %d): %r", failures + 1, err)
        else:
            if output is not None:
                return output
            last_error = None
            logger.warning("LLM returned no content (attempt %d)", failures + 1)
        failures += 1
        if max_retries is not None and failures > max_retries:
            raise RuntimeError(f"LLM request failed after {failures} attempt(s)") from last_error
        # Back off on *every* failure mode. Previously a None content response
        # looped straight back into another request without sleeping.
        time.sleep(config.GRAZIE_TIMEOUT_SEC)
def _persist_llm_cache():
    """Write ``LLM_CACHE`` to ``LLM_CACHE_FILE`` without leaving a torn file.

    The pickle is written and fsync'ed to a sibling ``.tmp`` file, then moved
    over the real cache file with ``os.replace``. On POSIX filesystems that
    rename is atomic, so an interruption mid-write leaves the previous cache
    intact. Without this, a truncated pickle would make the start-up
    ``pickle.load`` fail on the next run.
    """
    import os

    tmp_path = os.fspath(LLM_CACHE_FILE) + ".tmp"
    with open(tmp_path, "wb") as tmp_file:
        pickle.dump(obj=LLM_CACHE, file=tmp_file)
        tmp_file.flush()
        os.fsync(tmp_file.fileno())
    os.replace(tmp_path, LLM_CACHE_FILE)


def generate_for_prompt(prompt):
    """Return the next response for *prompt*, serving cached responses first.

    Each call with the same prompt yields the next response in
    ``LLM_CACHE[prompt]``; ``LLM_CACHE_USED[prompt]`` counts how many have been
    handed out in this process. Once every cached response has been used, a
    new one is requested via ``llm_request``, appended to the cache, and the
    cache is persisted to disk before being returned.

    Args:
        prompt: Prompt text; also used as the cache key, so it must be hashable.

    Returns:
        A response string for *prompt*.
    """
    responses = LLM_CACHE.setdefault(prompt, [])
    used = LLM_CACHE_USED.setdefault(prompt, 0)
    # Only query the model once all cached responses for this prompt are consumed.
    while used >= len(responses):
        responses.append(llm_request(prompt))
        # Persist after every new response so responses obtained so far survive a crash.
        _persist_llm_cache()
    LLM_CACHE_USED[prompt] = used + 1
    return responses[used]