Update app.py
app.py
CHANGED
@@ -11,6 +11,9 @@ import pandas as pd # If you're working with DataFrames
 import matplotlib.figure # If you're using matplotlib figures
 import numpy as np
 
+from custom_agent import CustomHfAgent
+
+
 # For Altair charts
 import altair as alt
 
@@ -33,19 +36,7 @@ transformers_logger = logging.getLogger("transformers.file_utils")
 transformers_logger.setLevel(logging.INFO) # Set the desired logging level
 
 
-
-
-
-
-
-
-
-
-
-
-
 import time
-from transformers import load_tool, Agent
 import torch
 
 class ToolLoader:
@@ -62,39 +53,6 @@ class ToolLoader:
                 log_response(f"Error loading tool '{tool_name}': {e}")
         return loaded_tools
 
-class CustomHfAgent(Agent):
-    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
-        super().__init__(
-            chat_prompt_template=chat_prompt_template,
-            run_prompt_template=run_prompt_template,
-            additional_tools=additional_tools,
-        )
-        self.url_endpoint = url_endpoint
-        self.token = token
-        self.input_params = input_params
-
-    def generate_one(self, prompt, stop):
-        headers = {"Authorization": self.token}
-        max_new_tokens = self.input_params.get("max_new_tokens", 192)
-        parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
-        inputs = {
-            "inputs": prompt,
-            "parameters": parameters,
-        }
-        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
-
-        if response.status_code == 429:
-            log_response("Getting rate-limited, waiting a tiny bit before trying again.")
-            time.sleep(1)
-            return self._generate_one(prompt)
-        elif response.status_code != 200:
-            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
-        log_response(response)
-        result = response.json()[0]["generated_text"]
-        for stop_seq in stop:
-            if result.endswith(stop_seq):
-                return result[: -len(stop_seq)]
-        return result
 
 def handle_submission(user_message, selected_tools, url_endpoint):
 
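The class deleted above is evidently being moved into a separate custom_agent module, which app.py now imports; only the app.py side of the refactor appears in this commit. A minimal sketch of what custom_agent.py would need, inferred from the names used in the removed code (the log_response stand-in is an assumption):

# custom_agent.py - presumed companion module, not part of this commit.
# Imports inferred from the removed code above; method bodies are the ones
# deleted from app.py.
import time

import requests
from transformers import Agent

log_response = print  # assumption: stand-in for app.py's logging helper


class CustomHfAgent(Agent):
    def __init__(self, url_endpoint, token, chat_prompt_template=None,
                 run_prompt_template=None, additional_tools=None, input_params=None):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params

    def generate_one(self, prompt, stop):
        ...  # body unchanged from the implementation removed from app.py above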
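For context, a hedged sketch of how handle_submission() in app.py might wire up the imported agent after this change. Only CustomHfAgent's constructor signature and the 192-token default come from the code shown above; the endpoint, token, and tool values are illustrative placeholders:

# Hypothetical call site in app.py - not part of this diff.
agent = CustomHfAgent(
    url_endpoint=url_endpoint,             # e.g. a Hugging Face Inference API URL selected in the UI
    token="Bearer <hf_api_token>",         # generate_one() sends this value verbatim as the Authorization header
    additional_tools=selected_tools,       # assumption: tools returned by ToolLoader
    input_params={"max_new_tokens": 192},  # matches the default generate_one() falls back to
)
response_text = agent.run(user_message)    # run()/chat() are inherited from transformers' Agent base class

Since only the class's location changed, existing call sites that construct CustomHfAgent should only need the new import.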