"""
Module: custom_agent

This module provides a custom class, CustomHfAgent, for interacting with the Hugging Face model API.

Dependencies:
- time: Standard library module, used to pause before retrying rate-limited requests.
- requests: HTTP library for making requests.
- transformers: Hugging Face's transformers library for NLP tasks.
- utils.logger: Custom logger module for logging responses.

Classes:
- CustomHfAgent: A custom class for interacting with the Hugging Face model API.
"""

import time
import requests
from transformers import Agent
from utils.logger import log_response

class CustomHfAgent(Agent):
    """A custom class for interacting with the Hugging Face model API."""

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        """
        Initialize the CustomHfAgent.

        Args:
        - url_endpoint (str): The URL endpoint for the Hugging Face model API.
        - token (str): The authentication token required to access the API.
        - chat_prompt_template (str, optional): Template for chat prompts.
        - run_prompt_template (str, optional): Template for run prompts.
        - additional_tools (list, optional): Additional tools for the agent.
        - input_params (dict, optional): Extra generation parameters (e.g. max_new_tokens); defaults to an empty dict.

        Returns:
        - None
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        # Default to an empty dict so later .get() lookups work when no params are given.
        self.input_params = input_params or {}

    def generate_one(self, prompt, stop):
        """
        Generate one response from the Hugging Face model.

        Args:
        - prompt (str): The prompt to generate a response for.
        - stop (list): A list of strings indicating where to stop generating text.

        Returns:
        - str: The generated text, with any trailing stop sequence removed.
        """
        # The token string is sent verbatim; include a "Bearer " prefix if the endpoint requires one.
        headers = {"Authorization": self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        log_response(f"Request payload: {inputs}")
        try:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers, timeout=300)
        except (requests.Timeout, requests.ConnectionError) as err:
            # Without a response there is nothing to retry on or parse, so surface the failure.
            raise ValueError(f"Request to {self.url_endpoint} failed: {err}") from err
        if response.status_code == 429:
            log_response("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Error {response.status_code} for request {inputs}: {response.json()}")
        log_response(response)
        result = response.json()[0]["generated_text"]
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
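

# Minimal usage sketch (illustrative only). The endpoint URL and the
# HF_API_TOKEN environment variable below are placeholders, assuming a hosted
# text-generation endpoint that expects a Bearer token; adjust both to match
# your own deployment.
if __name__ == "__main__":
    import os

    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",  # placeholder endpoint
        token="Bearer " + os.environ.get("HF_API_TOKEN", ""),  # placeholder token source
        input_params={"max_new_tokens": 64},
    )
    # generate_one() posts the prompt to the endpoint and strips a trailing stop sequence.
    print(agent.generate_one("Write a haiku about the sea.", stop=["\n\n"]))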