File size: 3,775 Bytes
fdfb0c4
45f17fe
d6555d8
 
6bab521
b066a4d
 
3709e0d
d6555d8
3709e0d
6bab521
 
 
3709e0d
6bab521
 
4dc413a
6bab521
4dc413a
b066a4d
 
 
 
7755f96
6bab521
7755f96
6bab521
7755f96
 
45f17fe
7755f96
 
 
 
 
 
6bab521
 
7755f96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b066a4d
 
 
 
 
 
 
 
 
 
6bab521
b066a4d
 
 
 
6bab521
 
b066a4d
6bab521
 
b066a4d
6bab521
 
b066a4d
 
6bab521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import base64
import io
import os
import time

import requests
import streamlit as st
import torch
from PIL import Image  # required by handle_submission's Image.open (was an unimported name)

# From transformers import BertModel, BertTokenizer
from transformers import HfAgent, load_tool
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent

# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")

# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))

# Load tools from the Hugging Face Hub: a text-to-image generator and a
# latent upscaler. These are fetched over the network at import time.
controlnet_transformer = load_tool("huggingface-tools/text-to-image")
upscaler = load_tool("diffusers/latent-upscaler-tool")


# Tools offered to the user via checkboxes in the Streamlit UI below.
tools = [controlnet_transformer, upscaler]

# Define the custom HfAgent class
# Define the custom HfAgent class
class CustomHfAgent(Agent):
    """Agent that generates text via a Hugging Face Inference API endpoint.

    Args:
        url_endpoint: Full URL of the inference endpoint to POST prompts to.
        token: HF API token. Defaults to the ``HF_token`` environment
            variable, resolved lazily at construction time. (The original
            evaluated ``os.environ['HF_token']`` as a default argument at
            class-definition time, which raised ``KeyError`` on import
            whenever the variable was unset.)
        chat_prompt_template / run_prompt_template / additional_tools:
            Forwarded unchanged to ``Agent.__init__``.
    """

    def __init__(
        self,
        url_endpoint,
        token=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Resolve the token at runtime so merely importing this module does
        # not require HF_token to be set.
        self.token = token if token is not None else os.environ["HF_token"]

    def generate_one(self, prompt, stop):
        """POST ``prompt`` to the endpoint and return the generated text.

        Retries once after a one-second sleep on HTTP 429 (rate limiting);
        raises ``ValueError`` for any other non-200 status. Any trailing
        stop sequence echoed back by the API is stripped from the result.
        """
        # NOTE(review): the HF Inference API typically expects
        # "Bearer <token>" here — confirm the expected header format.
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 192, "return_full_text": False, "stop": stop},
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # Bug fix: the original retried via self._generate_one(prompt),
            # a method that does not exist, and dropped the `stop` argument.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence; strip it if present.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result

# --- Streamlit UI -----------------------------------------------------------
st.title("Hugging Face Agent")

# Input field for the user's message. st.text_input returns the current
# string value directly.
message_input = st.text_input("Enter your message:", "")

# One checkbox per tool. Bug fix: the original stored only the bare
# st.checkbox(...) booleans, but handle_submission() below unpacks
# `for tool, checkbox in tool_checkboxes`, so each entry must be a
# (tool, is_checked) pair.
tool_checkboxes = [(tool, st.checkbox(f"Use {tool}")) for tool in tools]

# Submit button; st.button returns True on the script rerun triggered by
# the click.
submit_button = st.button("Submit")

# Define the callback function to handle the form submission
def handle_submission():
    """Run the agent on the user's message with the selected tools and
    render the response (image or text) into the Streamlit page."""
    # Bug fix: st.text_input returns the string itself; the original read a
    # nonexistent `.value` attribute.
    message = message_input
    # Bug fix: keep only the tools whose checkbox is actually ticked (the
    # original selected every tool and ignored the checkbox state).
    selected_tools = [tool for tool, checked in tool_checkboxes if checked]

    # Initialize the agent against the StarCoder inference endpoint.
    # NOTE(review): transformers Agent.run() forwards extra kwargs into the
    # prompt rather than accepting a tool list, so the selected tools are
    # supplied via additional_tools here instead of the original
    # `agent.run(message, tools=...)` — confirm against the Agents API.
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ["HF_token"],
        additional_tools=selected_tools,
    )

    # Run the agent with the user's message.
    response = agent.run(message)

    # Display the agent's response.
    if response.startswith("Image:"):
        # Decode the base64 payload after the "Image:" prefix and render it.
        image_data = base64.b64decode(response.split(",")[1])
        img = Image.open(io.BytesIO(image_data))
        st.image(img)
    else:
        # Plain text response.
        st.write(response)

# Trigger the agent when the user clicks Submit. Bug fix: the original
# defined handle_submission() but never connected it to the button.
if submit_button:
    handle_submission()

# "Ask Again": clicking any Streamlit button reruns the whole script, which
# redraws the input field — no callback is needed. The original created this
# button twice, called a nonexistent `.do()` method on its return value
# (st.button returns a bool), assigned to the read-only `message_input.value`,
# and referenced an undefined module-level `agent`.
st.button("Ask Again")