File size: 3,864 Bytes
fdfb0c4
45f17fe
d6555d8
 
6bab521
b066a4d
 
3709e0d
d6555d8
3709e0d
6bab521
 
 
3709e0d
6bab521
 
4dc413a
6bab521
4dc413a
b066a4d
 
 
 
7755f96
6bab521
7755f96
6bab521
7755f96
 
45f17fe
7755f96
 
 
 
 
 
6bab521
 
7755f96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b066a4d
 
 
 
 
 
 
 
 
 
8fe5a03
b066a4d
 
 
 
6bab521
 
b066a4d
6bab521
 
b066a4d
6bab521
 
b066a4d
 
6bab521
 
 
 
 
 
 
 
 
 
 
eaf08d2
4541439
eaf08d2
8fe5a03
6bab521
 
 
 
 
 
 
 
 
 
4541439
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import base64
import io
import os
import time

import requests
import streamlit as st
import torch

# From transformers import BertModel, BertTokenizer
from transformers import HfAgent, load_tool
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent

# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")

# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))

# Load tools
# Fetch both remote tools (text-to-image first, then the latent upscaler) and
# keep them under their own names as well as in the ordered `tools` list.
controlnet_transformer, upscaler = (
    load_tool("huggingface-tools/text-to-image"),
    load_tool("diffusers/latent-upscaler-tool"),
)

# The order of `tools` matters: the checkbox list built from it is positional.
tools = [controlnet_transformer, upscaler]

# Define the custom HfAgent class
class CustomHfAgent(Agent):
    """Agent that generates text by POSTing prompts to an HTTP inference endpoint.

    Args:
        url_endpoint: URL the JSON prompt payload is posted to.
        token: Credential sent in the ``Authorization`` header. When omitted it
            is read from the ``HF_token`` environment variable at construction
            time (not at import time, so importing this module never fails
            merely because the variable is unset).
        chat_prompt_template: Forwarded unchanged to ``Agent.__init__``.
        run_prompt_template: Forwarded unchanged to ``Agent.__init__``.
        additional_tools: Forwarded unchanged to ``Agent.__init__``.

    Raises:
        KeyError: If ``token`` is omitted and ``HF_token`` is not set.
    """

    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Resolve the default lazily; a default evaluated in the signature would
        # raise KeyError while the class is being defined.
        self.token = os.environ["HF_token"] if token is None else token

    def generate_one(self, prompt, stop):
        """Request one completion for ``prompt`` and strip a trailing stop sequence.

        Args:
            prompt: Text sent as the ``inputs`` field of the request.
            stop: Iterable of stop sequences; sent to the endpoint and, if the
                returned text ends with one of them, removed from the result.

        Returns:
            The ``generated_text`` of the first item in the JSON response, minus
            a trailing stop sequence when present.

        Raises:
            ValueError: If the endpoint answers with a status other than 200
                (HTTP 429 is retried instead of raised).
        """
        # Add the "Bearer" scheme when the caller supplied a bare token; values
        # that already carry a scheme are sent untouched.
        authorization = self.token
        if not authorization.startswith(("Bearer", "Basic")):
            authorization = f"Bearer {authorization}"
        headers = {"Authorization": authorization}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 192, "return_full_text": False, "stop": stop},
        }

        # Retry for as long as the endpoint reports rate limiting. The original
        # retry called a method that does not exist and dropped `stop`.
        while True:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
            if response.status_code != 429:
                break
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)

        if response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence
        for stop_seq in stop:
            # Skip empty sequences: every string "ends with" "", and
            # result[:-0] would wipe the whole completion.
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result

# Create the Streamlit app
st.title("Hugging Face Agent")

# Input field for the user's message; st.text_input returns the current text.
message_input = st.text_input("Enter your message:", "")

# One checkbox per entry of `tools`, in the same order, so the resulting list of
# booleans can be zipped back onto `tools`. Label each box with the tool's name
# (falling back to its class name) instead of interpolating the tool object
# itself, and give every widget an explicit positional key so two tools that
# share a label cannot collide.
tool_checkboxes = [
    st.checkbox(
        f"Use {getattr(tool, 'name', type(tool).__name__)}",
        key=f"use_tool_{index}",
    )
    for index, tool in enumerate(tools)
]

# Submit button
#submit_button = st.button("Submit")

# Define the callback function to handle the form submission
def handle_submission():
    """Run the agent on the current message and render its answer.

    Reads the module-level ``message_input``, ``tools`` and ``tool_checkboxes``.
    A string answer starting with ``"Image:"`` is treated as
    ``"Image:<anything>,<base64 data>"`` and rendered as an image; any other
    answer is handed to ``st.write``.

    Raises:
        KeyError: If the ``HF_token`` environment variable is not set.
    """
    # st.text_input returns the entered text itself, so there is no `.value`.
    message = message_input
    # tool_checkboxes has one boolean per entry of `tools`, in the same order;
    # keep only the tools whose box is ticked.
    selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]

    # Initialize the agent. Tools are registered through the constructor:
    # keyword arguments to `run()` are exposed to the generated code as
    # variables and do not select tools.
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ['HF_token'],
        additional_tools=selected_tools or None,
    )

    # Run the agent with the user's message
    response = agent.run(message)

    # Display the agent's response
    if isinstance(response, str) and response.startswith("Image:"):
        _, separator, encoded = response.partition(",")
        if separator:
            # st.image accepts a BytesIO directly, so no image library is needed.
            st.image(io.BytesIO(base64.b64decode(encoded)))
            return
    # Text, or any other object the agent returned.
    st.write(response)

# Add a button to trigger the agent to respond again
#st.button("Ask Again
#st.button("Ask Again", key="ask_again_btn")

# st.button returns True only on the rerun triggered by a click, so branch on
# that value; the returned bool is not callable.
if st.button("Ask Again"):
    handle_submission()

# Define a callback function to handle the button click
def ask_again():
    """Clear the stored message and run the agent once with an empty prompt.

    NOTE(review): rebinding ``message_input`` only changes the module-level
    variable for the rest of the current script run; it does not clear the
    text box on screen. Clearing the widget would need a ``key`` on the
    ``st.text_input`` call and a write to ``st.session_state`` — confirm the
    intended behavior before wiring this function to a button.

    Raises:
        KeyError: If the ``HF_token`` environment variable is not set.
    """
    # message_input is the plain string returned by st.text_input, so it has no
    # `.value` attribute to assign to; rebind the name instead.
    global message_input
    message_input = ""

    # No module-level `agent` exists, so build one the same way
    # handle_submission does before running it with an empty message.
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ['HF_token'],
    )
    agent.run("")

# Add the callback function to the button
#st.button("Ask Again").do(ask_again)