Chris4K commited on
Commit
55f4af6
·
verified ·
1 Parent(s): af9f214

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -45
app.py CHANGED
@@ -11,6 +11,9 @@ import pandas as pd # If you're working with DataFrames
11
  import matplotlib.figure # If you're using matplotlib figures
12
  import numpy as np
13
 
 
 
 
14
  # For Altair charts
15
  import altair as alt
16
 
@@ -33,19 +36,7 @@ transformers_logger = logging.getLogger("transformers.file_utils")
33
  transformers_logger.setLevel(logging.INFO) # Set the desired logging level
34
 
35
 
36
-
37
-
38
-
39
-
40
-
41
-
42
-
43
-
44
-
45
-
46
-
47
  import time
48
- from transformers import load_tool, Agent
49
  import torch
50
 
51
  class ToolLoader:
@@ -62,39 +53,6 @@ class ToolLoader:
62
  log_response(f"Error loading tool '{tool_name}': {e}")
63
  return loaded_tools
64
 
65
- class CustomHfAgent(Agent):
66
- def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
67
- super().__init__(
68
- chat_prompt_template=chat_prompt_template,
69
- run_prompt_template=run_prompt_template,
70
- additional_tools=additional_tools,
71
- )
72
- self.url_endpoint = url_endpoint
73
- self.token = token
74
- self.input_params = input_params
75
-
76
- def generate_one(self, prompt, stop):
77
- headers = {"Authorization": self.token}
78
- max_new_tokens = self.input_params.get("max_new_tokens", 192)
79
- parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
80
- inputs = {
81
- "inputs": prompt,
82
- "parameters": parameters,
83
- }
84
- response = requests.post(self.url_endpoint, json=inputs, headers=headers)
85
-
86
- if response.status_code == 429:
87
- log_response("Getting rate-limited, waiting a tiny bit before trying again.")
88
- time.sleep(1)
89
- return self._generate_one(prompt)
90
- elif response.status_code != 200:
91
- raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
92
- log_response(response)
93
- result = response.json()[0]["generated_text"]
94
- for stop_seq in stop:
95
- if result.endswith(stop_seq):
96
- return result[: -len(stop_seq)]
97
- return result
98
 
99
  def handle_submission(user_message, selected_tools, url_endpoint):
100
 
 
11
  import matplotlib.figure # If you're using matplotlib figures
12
  import numpy as np
13
 
14
+ from custom_agent import CustomHfAgent
15
+
16
+
17
  # For Altair charts
18
  import altair as alt
19
 
 
36
  transformers_logger.setLevel(logging.INFO) # Set the desired logging level
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
39
  import time
 
40
  import torch
41
 
42
  class ToolLoader:
 
53
  log_response(f"Error loading tool '{tool_name}': {e}")
54
  return loaded_tools
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def handle_submission(user_message, selected_tools, url_endpoint):
58