File size: 6,066 Bytes
fdfb0c4
45f17fe
d6555d8
 
d7ca359
d2b3f71
6cc2332
 
d7ca359
6bab521
b066a4d
 
3709e0d
d6555d8
3709e0d
6bab521
 
 
3709e0d
6bab521
 
4dc413a
6bab521
4dc413a
b066a4d
e58825b
 
b066a4d
7755f96
6bab521
7755f96
6bab521
7755f96
 
45f17fe
7755f96
 
 
 
 
 
6bab521
 
7755f96
 
 
 
 
 
 
 
 
 
 
 
 
 
e58825b
7755f96
 
 
 
 
 
 
b066a4d
 
 
 
e0fddbb
b066a4d
 
4da4c03
e58825b
 
b066a4d
 
8fe5a03
b066a4d
 
 
 
e0fddbb
4da4c03
dfe0aa8
 
 
 
 
 
 
4da4c03
 
b066a4d
6bab521
4c6c6ab
b066a4d
6bab521
4c6c6ab
e8ab34d
b066a4d
12f410e
 
b066a4d
6cc2332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e58825b
 
 
603716d
e46d898
6cc2332
603716d
6cc2332
 
 
e46d898
 
 
 
 
 
6cc2332
 
 
12f410e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bab521
 
eaf08d2
4541439
eaf08d2
12b11fa
 
 
6bab521
 
 
 
 
 
 
 
 
 
4541439
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import base64
import io
import os
import time

import requests
import streamlit as st
import torch
from PIL import Image
#from pydub.playback import Audio
from pydub import AudioSegment

# From transformers import BertModel, BertTokenizer
from transformers import HfAgent, load_tool
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent

# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")

# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))

# Load the Hub tools the user can opt into, in the order they appear in the UI.
# NOTE(review): the variable names do not match the repos they load (a
# random-character tool and a text-generation tool) — consider renaming.
controlnet_transformer, upscaler = (
    load_tool(repo_id)
    for repo_id in ("Chris4K/random-character-tool", "Chris4K/text-generation-tool")
)

tools = [controlnet_transformer, upscaler]

# Define the custom HfAgent class
class CustomHfAgent(Agent):
    """Agent whose generations come from a Hugging Face Inference API endpoint.

    Args:
        url_endpoint: URL of a text-generation endpoint that accepts a JSON body
            of the form ``{"inputs": ..., "parameters": ...}``.
        token: API token. A bare token is sent as ``Bearer <token>``; a value
            that already starts with ``Bearer`` or ``Basic`` is sent as-is.
            When omitted, the ``HF_token`` environment variable is read at
            construction time (``KeyError`` if it is unset).
        chat_prompt_template, run_prompt_template, additional_tools:
            Forwarded unchanged to ``Agent.__init__``.
    """

    # Number of extra attempts made after an HTTP 429 before giving up.
    MAX_RETRIES = 5
    # Seconds to wait before retrying a rate-limited request.
    RETRY_DELAY = 1.0
    # Value sent as "max_new_tokens" with every generation request.
    MAX_NEW_TOKENS = 192

    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        if token is None:
            # Looked up here rather than in the signature so that merely importing
            # this module does not require the variable to be set.
            token = os.environ["HF_token"]
        self.url_endpoint = url_endpoint
        # The Authorization header needs an auth scheme; add "Bearer" to bare tokens.
        self.token = token if token.startswith(("Bearer", "Basic")) else f"Bearer {token}"

    def generate_one(self, prompt, stop):
        """Request one completion of ``prompt`` from the endpoint.

        Args:
            prompt: Prompt text sent as ``inputs``.
            stop: Stop sequences forwarded to the endpoint (``None`` means none).
                If the returned text ends with one of them, it is stripped.

        Returns:
            The generated text (``return_full_text`` is False, so without the prompt).

        Raises:
            ValueError: if the endpoint answers with a non-200 status, including
                when it is still rate-limiting after ``MAX_RETRIES`` retries.
        """
        stop = list(stop or [])
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": self.MAX_NEW_TOKENS, "return_full_text": False, "stop": stop},
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        # Bounded retry on rate limiting instead of unbounded recursion.
        for _ in range(self.MAX_RETRIES):
            if response.status_code != 429:
                break
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(self.RETRY_DELAY)
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code != 200:
            # Use the raw body: error responses are not guaranteed to be JSON.
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.text}")
        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence; skip empty sequences, which
        # would otherwise "match" and truncate the whole result.
        for stop_seq in stop:
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result

# Page header and user inputs.
st.title("Hugging Face Agent")

# The request text that will be handed to the agent.
message = st.text_input("Enter your message:", "")

# One opt-in checkbox per loaded tool; tool_checkboxes[i] is the state of the
# checkbox rendered for tools[i].
tool_checkboxes = list(
    map(st.checkbox, (f"Use {tool.name} --- {tool.description} " for tool in tools))
)

# Submit button
#submit_button = st.button("Submit")

# Define the callback function to handle the form submission
def handle_submission():
    """Run the agent on the current message and display its response.

    Used as the Submit button's ``on_click`` callback. Reads the module-level
    ``message`` and ``tool_checkboxes`` values, builds a ``CustomHfAgent``
    against the StarCoder Inference API endpoint with the ticked tools, runs
    it, and renders the result via ``_display_response``. Empty messages and
    agent/HTTP failures are reported in the page instead of raising.
    """
    if not message.strip():
        st.warning("Please enter a message first.")
        return

    # tool_checkboxes is built from `tools` in the same order.
    selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]

    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ['HF_token'],
        additional_tools=selected_tools,
    )

    try:
        response = agent.run(message)
    except (requests.RequestException, ValueError) as err:
        # CustomHfAgent.generate_one raises ValueError on non-200 responses.
        st.error(f"The agent failed to respond: {err}")
        return

    _display_response(response)


def _display_response(response):
    """Render an agent result with the Streamlit element matching its type."""
    if response is None:
        st.warning("The agent's response is None.")
    elif isinstance(response, Image.Image):
        st.image(response)
    elif isinstance(response, AudioSegment):
        # st.audio takes raw bytes, not an AudioSegment: export to WAV first.
        buffer = io.BytesIO()
        response.export(buffer, format="wav")
        st.audio(buffer.getvalue(), format="audio/wav")
    elif isinstance(response, str):
        decoded = _decode_audio_payload(response)
        if decoded is None:
            st.write(response)
        else:
            audio_bytes, mime_type = decoded
            st.audio(audio_bytes, format=mime_type)
    elif _contains(response, "text"):
        st.write(response)
    else:
        st.warning("Unrecognized response type.")


def _decode_audio_payload(text):
    """Decode a ``<header>,<base64>`` audio string such as a ``data:audio/...`` URI.

    Returns ``(audio_bytes, mime_type)`` when the header (text before the first
    comma) mentions "audio" and the remainder is non-empty valid base64;
    otherwise ``None`` so the caller shows the string as plain text.
    """
    header, sep, payload = text.partition(",")
    if not sep or "audio" not in header:
        return None
    try:
        audio_bytes = base64.b64decode(payload.strip(), validate=True)
    except ValueError:  # binascii.Error is a ValueError subclass
        return None
    if not audio_bytes:
        return None
    mime_type = "audio/wav"
    if header.startswith("data:"):
        # e.g. "data:audio/mpeg;base64" -> "audio/mpeg"
        mime_type = header[len("data:"):].split(";", 1)[0] or mime_type
    return audio_bytes, mime_type


def _contains(container, key):
    """Return ``key in container``, treating objects without ``in`` support as False."""
    try:
        return key in container
    except TypeError:
        return False

# Submit button: clicking it triggers handle_submission through on_click.
submit_button = st.button(label="Submit", on_click=handle_submission)

# Define a callback function to handle the button click
def ask_again():
    """Run the agent again on the current message and selected tools.

    Delegates to ``handle_submission``, which builds its own agent. The earlier
    body used ``message_input`` and ``agent``, neither of which exists at module
    scope, so it always raised NameError.

    NOTE(review): clearing the message box would require giving the
    ``st.text_input`` a ``key`` and resetting ``st.session_state[key]``; that
    widget currently has no key, so this callback does not clear it.
    """
    handle_submission()

# Add the callback function to the button
#st.button("Ask Again").do(ask_again)