|
""" |
|
Module: custom_agent |
|
|
|
This module provides a custom class, CustomHfAgent, for interacting with the Hugging Face model API. |
|
|
|
Dependencies: |
|
- time: Standard Python time module for time-related operations. |
|
- requests: HTTP library for making requests. |
|
- transformers: Hugging Face's transformers library for NLP tasks. |
|
- utils.logger: Custom logger module for logging responses. |
|
|
|
Classes: |
|
- CustomHfAgent: A custom class for interacting with the Hugging Face model API. |
|
""" |
|
|
|
import time |
|
import requests |
|
from transformers import Agent |
|
from utils.logger import log_response |
|
|
|
class CustomHfAgent(Agent):
    """A custom agent for querying a Hugging Face model API endpoint.

    Sends text-generation requests over HTTP, retries (bounded) on HTTP 429
    rate limiting, and strips a trailing stop sequence from the generated
    text before returning it.
    """

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        """
        Initialize the CustomHfAgent.

        Args:
            url_endpoint (str): The URL endpoint for the Hugging Face model API.
            token (str): The authentication token required to access the API
                (used verbatim as the ``Authorization`` header value).
            chat_prompt_template (str, optional): Template for chat prompts.
            run_prompt_template (str, optional): Template for run prompts.
            additional_tools (list, optional): Additional tools for the agent.
            input_params (dict, optional): Additional generation parameters,
                e.g. ``{"max_new_tokens": 256}``. Defaults to an empty dict.

        Returns:
            None
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        # Bug fix: default to {} so `.get()` lookups in generate_one never
        # raise AttributeError when the caller omits input_params.
        self.input_params = input_params if input_params is not None else {}

    def generate_one(self, prompt, stop):
        """
        Generate one response from the Hugging Face model.

        Args:
            prompt (str): The prompt to generate a response for.
            stop (list): Strings indicating where to stop generating text; a
                trailing stop sequence is stripped from the returned result.

        Returns:
            str: The generated response.

        Raises:
            requests.Timeout: If the endpoint does not respond within 300 s.
            requests.ConnectionError: If the endpoint is unreachable.
            ValueError: If the API returns a non-200, non-429 status code, or
                the rate limit persists after all retries.
        """
        headers = {"Authorization": self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }

        # Bounded retry loop replaces the original's unbounded recursion on
        # HTTP 429, which could exhaust the stack under sustained rate
        # limiting.
        max_retries = 60
        for _ in range(max_retries):
            # Bug fix: the original swallowed Timeout/ConnectionError with
            # `pass` and then read the unbound `response` variable, raising
            # NameError. Let these exceptions propagate to the caller.
            response = requests.post(self.url_endpoint, json=inputs, headers=headers, timeout=300)
            if response.status_code == 429:
                log_response("Getting rate-limited, waiting a tiny bit before trying again.")
                time.sleep(1)
                continue
            if response.status_code != 200:
                raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
            break
        else:
            # Every attempt came back 429.
            raise ValueError(f"Still rate-limited after {max_retries} attempts: {inputs}")

        log_response(response)
        result = response.json()[0]["generated_text"]
        # Mirror the server-side stop behavior: drop a trailing stop sequence
        # so callers receive only the generated continuation.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
|
|