|
import base64
import os
import time

import requests
import streamlit as st
import torch

from transformers import HfAgent, load_tool
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Remote tools the agent may invoke, fetched from the transformers tool hub.
controlnet_transformer = load_tool("huggingface-tools/text-to-image")

upscaler = load_tool("diffusers/latent-upscaler-tool")


# Tools exposed to the user as checkboxes in the UI below.
tools = [controlnet_transformer, upscaler]
|
|
|
|
|
class CustomHfAgent(Agent):
    """Agent backed by a Hugging Face Inference API text-generation endpoint.

    Each ``generate_one`` call POSTs the prompt to ``url_endpoint`` and
    returns the generated text, retrying once per call when rate-limited.
    """

    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Resolve the token lazily: the original default (os.environ['HF_token'])
        # was evaluated at class-definition time and raised KeyError on import
        # whenever the variable was unset, even if a token was passed explicitly.
        self.token = token if token is not None else os.environ.get("HF_token", "")

    def generate_one(self, prompt, stop):
        """POST one generation request and return the text, minus any trailing stop sequence.

        Raises ValueError for any HTTP status other than 200 or 429.
        """
        # The HF Inference API expects "Authorization: Bearer <token>".
        headers = {"Authorization": f"Bearer {self.token}"}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 192, "return_full_text": False, "stop": stop},
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # The original retried via the non-existent self._generate_one(prompt);
            # retry with the correct method name and both arguments.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]

        # Strip a trailing stop sequence so callers get only the generated body.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
|
|
|
|
|
# --- Streamlit UI ----------------------------------------------------------
st.title("Hugging Face Agent")

# Current user message; st.text_input returns the entered text as a plain str.
message_input = st.text_input("Enter your message:", "")

# One checkbox per tool. NOTE(review): the label interpolates the tool object
# itself, so users see its repr() — consider a human-readable name.
tool_checkboxes = [st.checkbox(f"Use {tool}") for tool in tools]

# True on the rerun triggered by a click. NOTE(review): this value is never
# checked and handle_submission is never invoked anywhere — verify the wiring
# (e.g. `if submit_button: handle_submission()` after the def).
submit_button = st.button("Submit")
|
|
|
|
|
def handle_submission():
    """Run the agent on the user's message and render the response.

    Reads the module-level ``message_input`` string and ``tool_checkboxes``
    selections, invokes the StarCoder endpoint, and displays either an image
    (for "Image:"-prefixed, base64-encoded responses) or plain text.
    """
    # st.text_input returns the current value as a plain str; the original
    # `message_input.value` raised AttributeError.
    message = message_input
    # Pair each tool with its checkbox state; the original tried to unpack the
    # list of booleans directly, which raised TypeError.
    selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]

    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ['HF_token'],
    )

    response = agent.run(message, tools=selected_tools)

    if response.startswith("Image:"):
        # Response assumed to look like "Image:<...>,<base64 payload>" — TODO confirm.
        image_data = base64.b64decode(response.split(",")[1])
        # st.image accepts raw image bytes directly, so no PIL/io round-trip is
        # needed (the original referenced Image and io without importing them).
        st.image(image_data)
    else:
        st.write(response)
|
|
|
|
|
# NOTE(review): unwired button — its return value is discarded and no callback
# is attached, so clicking it only reruns the script. Verify intent.
st.button("Ask Again")
|
|
|
|
|
def ask_again():
    """Callback for the "Ask Again" button: reset so the user can ask anew.

    The original body failed at runtime: st.text_input returns a plain str
    (there is no ``.value`` attribute to assign), and ``agent`` was not in
    scope here. Streamlit reruns the whole script on every interaction, so a
    fresh text_input is rendered anyway — the callback needs no work.
    """
    # Intentionally a no-op; see docstring.
|
|
|
|
|
# st.button returns a bool and has no .do() method; register the callback with
# on_click instead. A distinct key avoids a DuplicateWidgetID clash with the
# other "Ask Again" button rendered earlier in the script.
st.button("Ask Again", key="ask_again_cb", on_click=ask_again)