|
import secrets

import requests
import streamlit as st
import torch
from transformers import HfAgent, load_tool
from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hub tools the agent may call: a text-to-image generator and a latent upscaler.
controlnet_transformer = load_tool("huggingface-tools/text-to-image")

upscaler = load_tool("diffusers/latent-upscaler-tool")


# Tool list surfaced to the user as checkboxes further down the script.
tools = [controlnet_transformer, upscaler ]
|
|
|
|
|
|
|
from huggingface_hub import login |
|
|
|
|
|
|
|
|
|
from transformers.tools import HfAgent |
|
from transformers.tools import Agent |
|
|
|
|
|
import time |
|
|
|
from huggingface_hub import HfFolder, hf_hub_download, list_spaces |
|
|
|
|
|
|
|
|
|
class CustomHfAgent(Agent):
    """
    Agent that uses an inference endpoint to generate code.

    Args:
        url_endpoint (`str`):
            The name of the url endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `chat_prompt_template.txt` in this repo in this case.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method. Can be the
            actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
            `run_prompt_template.txt` in this repo in this case.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import HfAgent

    agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        # BUGFIX: the default used to be `secrets.HF_token`, but the stdlib
        # `secrets` module has no such attribute, so the class definition
        # itself raised AttributeError at import time. `None` restores the
        # documented behavior (fall back to the cached huggingface-cli token).
        # NOTE(review): if a Streamlit secret was intended here, the caller
        # should pass `st.secrets["HF_token"]` explicitly — confirm.
        self.url_endpoint = url_endpoint
        if token is None:
            # Use the token stored by `huggingface-cli login`.
            self.token = f"Bearer {HfFolder().get_token()}"
        elif token.startswith("Bearer") or token.startswith("Basic"):
            # Caller already supplied a full Authorization header value.
            self.token = token
        else:
            # Bare token string: wrap it in the Bearer scheme.
            self.token = f"Bearer {token}"
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_one(self, prompt, stop):
        """POST `prompt` to the inference endpoint and return the generated text.

        Retries (recursively) after a 1 s pause on HTTP 429 rate limiting and
        raises `ValueError` on any other non-200 status. Any trailing stop
        sequence from `stop` is stripped from the returned text.
        """
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 192, "return_full_text": False, "stop": stop},
        }
        # (Removed a debug `print(inputs)` that leaked the full prompt to stdout.)
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # BUGFIX: used to call `self._generate_one(prompt)` — no such method
            # exists and the required `stop` argument was dropped, so every
            # rate-limit retry raised AttributeError.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]

        # The endpoint may echo the stop sequence at the end; trim it off.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): this module-level agent is never used afterwards —
# `handle_submission` constructs its own HfAgent. Presumably leftover code;
# confirm before removing.
agent = CustomHfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Streamlit UI ---
st.title("Hugging Face Agent")


# Free-form user message forwarded to the agent on submit.
message_input = st.text_input("Enter your message:", "")


# One checkbox per loaded tool; zipped back against `tools` in handle_submission.
tool_checkboxes = [st.checkbox(f"Use {tool}") for tool in tools]
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_submission():
    """Run an HfAgent on the user's message with the tools they ticked.

    Reads the module-level `message_input` and `tool_checkboxes` widgets,
    builds an agent restricted to the selected tools, runs it, and renders
    the response. Returns the string "done".
    """
    message = message_input
    # Keep only the tools whose checkbox is ticked.
    selected_tools = [tool for tool, checkbox in zip(tools, tool_checkboxes) if checkbox]

    # BUGFIX: the original passed `additional_tools=tools`, ignoring
    # `selected_tools` entirely — the checkboxes had no effect.
    agent = HfAgent("https://api-inference.huggingface.co/models/THUDM/agentlm-7b", additional_tools=selected_tools)

    response = agent.run(message)
    # BUGFIX: the original rendered `f"{response:.4f}"`, applying a float
    # format spec to the agent output (a string or image object), which
    # raises at runtime. Display the response as-is.
    st.text(response)
    return "done"
|
|
|
|
|
|
|
# Clicking Submit triggers handle_submission via Streamlit's on_click callback.
submit_button = st.button("Submit", on_click=handle_submission)
|
|
|
|