Chris4K's picture
static
bc53764
raw
history blame
3.14 kB
"""
Module: custom_agent
This module provides a custom class, CustomHfAgent, for interacting with the Hugging Face model API.
Dependencies:
- time: Standard Python time module for time-related operations.
- requests: HTTP library for making requests.
- transformers: Hugging Face's transformers library for NLP tasks.
- utils.logger: Custom logger module for logging responses.
Classes:
- CustomHfAgent: A custom class for interacting with the Hugging Face model API.
"""
import time
import requests
from transformers import Agent
from utils.logger import log_response
class CustomHfAgent(Agent):
    """A custom class for interacting with the Hugging Face model API."""

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        """
        Initialize the CustomHfAgent.

        Args:
        - url_endpoint (str): The URL endpoint for the Hugging Face model API.
        - token (str): The authentication token required to access the API.
          A bare token is sent as ``Bearer <token>``; a value that already
          starts with ``Bearer`` is sent unchanged.
        - chat_prompt_template (str): Template for chat prompts.
        - run_prompt_template (str): Template for run prompts.
        - additional_tools (list): Additional tools for the agent.
        - input_params (dict): Additional parameters for input. Only
          ``max_new_tokens`` is read (default 192); ``None`` means no overrides.

        Returns:
        - None
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params

    def _auth_headers(self):
        """Return the HTTP headers carrying the Authorization token.

        Adds the ``Bearer`` scheme when the caller passed a bare token. When no
        token is set, no Authorization header is sent.
        """
        token = self.token
        if not token:
            return {}
        if not token.startswith("Bearer"):
            token = f"Bearer {token}"
        return {"Authorization": token}

    def _post_with_retry(self, inputs, headers):
        """POST ``inputs`` to the endpoint, retrying every second while rate-limited (HTTP 429).

        Network failures (timeout / connection error) are logged and re-raised
        instead of being swallowed. Swallowing them used to surface later as a
        confusing UnboundLocalError.
        """
        # A loop instead of recursion: a long rate-limit streak must not hit
        # Python's recursion limit.
        while True:
            try:
                response = requests.post(self.url_endpoint, json=inputs, headers=headers, timeout=300)
            except (requests.Timeout, requests.ConnectionError) as err:
                log_response(f"Request to {self.url_endpoint} failed: {err}")
                raise
            if response.status_code != 429:
                return response
            log_response("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)

    @staticmethod
    def _strip_stop_sequence(text, stop):
        """Remove the first stop sequence from ``stop`` that ``text`` ends with.

        Empty stop strings are skipped: ``text[:-0]`` would otherwise return ``""``.
        """
        for stop_seq in stop or ():
            if stop_seq and text.endswith(stop_seq):
                return text[: -len(stop_seq)]
        return text

    def generate_one(self, prompt, stop):
        """
        Generate one response from the Hugging Face model.

        Args:
        - prompt (str): The prompt to generate a response for.
        - stop (list): A list of strings indicating where to stop generating text.

        Returns:
        - str: The generated response, with a trailing stop sequence removed.

        Raises:
        - ValueError: If the endpoint answers with a status other than 200 or 429.
        - requests.Timeout / requests.ConnectionError: If the request cannot be completed.
        """
        # input_params defaults to None in __init__, so guard before .get().
        params = self.input_params or {}
        max_new_tokens = params.get("max_new_tokens", 192)
        parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        log_response(inputs)

        response = self._post_with_retry(inputs, self._auth_headers())
        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; fall back to raw text
            # so the HTTP status is not masked by a decode error.
            try:
                detail = response.json()
            except ValueError:
                detail = response.text
            raise ValueError(f"Errors {inputs} {response.status_code}: {detail}")

        log_response(response)
        result = response.json()[0]["generated_text"]
        return self._strip_stop_sequence(result, stop)