File size: 3,864 Bytes
fdfb0c4 45f17fe d6555d8 6bab521 b066a4d 3709e0d d6555d8 3709e0d 6bab521 3709e0d 6bab521 4dc413a 6bab521 4dc413a b066a4d 7755f96 6bab521 7755f96 6bab521 7755f96 45f17fe 7755f96 6bab521 7755f96 b066a4d 8fe5a03 b066a4d 6bab521 b066a4d 6bab521 b066a4d 6bab521 b066a4d 6bab521 eaf08d2 4541439 eaf08d2 8fe5a03 6bab521 4541439 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import base64
import io
import os
import time

import requests
import streamlit as st
import torch
# From transformers import BertModel, BertTokenizer
from transformers import HfAgent, load_tool
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent
# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")
# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))
# Load the Hub-hosted tools offered to the agent, in this order:
# the "huggingface-tools/text-to-image" tool, then the latent upscaler.
# NOTE(review): despite its name, `controlnet_transformer` is loaded from the
# text-to-image repo id, not a ControlNet repo — confirm the intended tool.
controlnet_transformer, upscaler = (
    load_tool(repo_id)
    for repo_id in ("huggingface-tools/text-to-image", "diffusers/latent-upscaler-tool")
)
tools = [controlnet_transformer, upscaler]
# Define the custom HfAgent class
class CustomHfAgent(Agent):
    """Agent that generates its code through a Hugging Face Inference API endpoint.

    Args:
        url_endpoint: URL of the Inference API model endpoint that prompts are POSTed to.
        token: Hugging Face API token. A bare token is sent as ``Bearer <token>``;
            a value already starting with ``Bearer`` or ``Basic`` is sent as-is.
            When omitted, the ``HF_token`` environment variable is read at
            construction time (KeyError if it is unset).
        chat_prompt_template: Forwarded unchanged to ``transformers.Agent``.
        run_prompt_template: Forwarded unchanged to ``transformers.Agent``.
        additional_tools: Forwarded unchanged to ``transformers.Agent``.
    """

    def __init__(
        self,
        url_endpoint,
        token=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Resolve the default lazily: an `os.environ[...]` default argument is
        # evaluated when the class is defined and would crash the whole module
        # import whenever the variable is missing.
        if token is None:
            token = os.environ["HF_token"]
        # The Inference API authenticates via an "Authorization: Bearer <token>" header.
        if not token.startswith(("Bearer", "Basic")):
            token = f"Bearer {token}"
        self.token = token

    def generate_one(self, prompt, stop):
        """Generate a single completion for ``prompt`` via the Inference API.

        Args:
            prompt: Fully formatted prompt text sent as the request's ``inputs``.
            stop: Stop strings forwarded to the API; if the returned text ends
                with one of them, that suffix is stripped.

        Returns:
            The generated text, minus a trailing stop sequence.

        Raises:
            ValueError: If the API answers with a status other than 200
                (HTTP 429 is retried instead).
        """
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 192, "return_full_text": False, "stop": stop},
        }
        # Retry in a loop (not by recursion) for as long as the API rate-limits us,
        # re-sending the same prompt and stop sequences each time.
        while True:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
            if response.status_code != 429:
                break
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
        if response.status_code != 200:
            # Use the raw body: error pages are not guaranteed to be JSON, and a
            # decode failure here would hide the real status code.
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.text}")
        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence
        for stop_seq in stop or ():
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
# Create the Streamlit app
st.title("Hugging Face Agent")
# Input field for the user's message; st.text_input returns the current text as a str.
message_input = st.text_input("Enter your message:", "")
# One checkbox per entry of `tools`, in the same order; each list element is
# that checkbox's current bool state.
# The label uses the tool's `name` instead of its default repr. Unless the tool
# class overrides it, the repr embeds a memory address that changes on every
# Streamlit rerun, which would reset the checkbox. The index-based key keeps
# each widget's identity stable and unique even if two tools share a name.
tool_checkboxes = [
    st.checkbox(
        f"Use {getattr(tool, 'name', '') or type(tool).__name__}",
        key=f"tool_checkbox_{index}",
    )
    for index, tool in enumerate(tools)
]
# Define the callback function to handle the form submission
def handle_submission():
    """Run the agent on the entered message with the checked tools and display the answer.

    Reads the module-level ``message_input`` (the str returned by
    ``st.text_input``), ``tools`` and ``tool_checkboxes`` (one bool per tool,
    same order). Builds a ``CustomHfAgent`` against the StarCoder Inference API
    endpoint, with the checked tools registered, and renders the result with
    Streamlit.
    """
    # st.text_input returns a plain string, which has no `.value` attribute.
    message = message_input
    if not message:
        st.warning("Please enter a message.")
        return
    # Pair each tool with its checkbox state positionally and keep only the checked ones.
    selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]
    # Extra keyword arguments to Agent.run become variables for the generated
    # code, not tools, so the selected tools are registered on the agent instead.
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ["HF_token"],
        additional_tools=selected_tools or None,
    )
    response = agent.run(message)
    # Display the agent's response
    if isinstance(response, str) and response.startswith("Image:"):
        # Assumes an "Image:<header>,<base64 payload>" string — TODO confirm the format.
        _, _, encoded = response.partition(",")
        image_data = base64.b64decode(encoded)
        # st.image accepts an in-memory byte stream directly, so PIL is not needed.
        st.image(io.BytesIO(image_data))
    else:
        # The agent may return non-string objects; st.write accepts arbitrary
        # objects, so they no longer crash on `.startswith`.
        st.write(response)
# Add a button that runs the agent on the current message.
# st.button returns a bool (True only on the rerun triggered by the click),
# so it is used as a condition rather than called like a decorator.
if st.button("Ask Again"):
    handle_submission()
# Define a callback function to handle the button click
def ask_again():
    """Clear the stored message and run a fresh agent on an empty prompt.

    Returns:
        Whatever ``CustomHfAgent.run`` returns for the empty prompt.
    """
    # `message_input` is the plain str returned by st.text_input and has no
    # `.value` attribute. Rebinding the module-level name clears the value seen
    # by the rest of this script run; it does not clear the on-screen widget.
    global message_input
    message_input = ""
    # No module-level `agent` exists, so build one the same way handle_submission does.
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ["HF_token"],
    )
    # Run the agent again with an empty message
    return agent.run("")
# Add the callback function to the button
#st.button("Ask Again").do(ask_again) |