File size: 5,473 Bytes
fdfb0c4
45f17fe
d6555d8
 
d7ca359
d2b3f71
6cc2332
 
fb5ba89
 
 
 
 
 
 
d7ca359
6bab521
b066a4d
 
3709e0d
d6555d8
3709e0d
6bab521
 
 
3709e0d
6bab521
 
4dc413a
6bab521
4dc413a
b066a4d
bd4ac0b
 
b066a4d
7755f96
bd4ac0b
7755f96
6bab521
7755f96
 
fb5ba89
7755f96
 
 
 
 
 
6bab521
 
fb5ba89
7755f96
 
 
fb5ba89
 
7755f96
 
fb5ba89
 
7755f96
 
 
 
 
 
 
 
e58825b
7755f96
 
 
 
 
 
 
fb5ba89
b066a4d
 
e0fddbb
b066a4d
 
4da4c03
e58825b
 
b066a4d
 
8fe5a03
b066a4d
bd4ac0b
 
 
 
 
 
 
 
 
b066a4d
 
 
e0fddbb
4da4c03
dfe0aa8
 
bd4ac0b
dfe0aa8
b066a4d
bd4ac0b
b066a4d
bd4ac0b
6bab521
bd4ac0b
 
b066a4d
fb5ba89
6cc2332
 
e58825b
fb5ba89
e58825b
603716d
e46d898
6cc2332
603716d
6cc2332
 
 
e46d898
 
 
 
 
 
6cc2332
 
 
12f410e
 
fb5ba89
12f410e
 
eaf08d2
12b11fa
 
6bab521
 
 
 
 
 
 
 
 
 
fb5ba89
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import base64
import io
import os
import time

import IPython
import requests
import soundfile as sf
import streamlit as st
from PIL import Image
#from pydub.playback import Audio
from pydub import AudioSegment

def play_audio(audio, *, path="speech_converted.wav", samplerate=16000):
    """Write *audio* to a WAV file and return an IPython audio player for it.

    Args:
        audio: Waveform to save. Objects exposing ``.numpy()`` are converted
            with it first; anything else is passed to ``soundfile.write``
            unchanged.
        path: Destination file; overwritten on every call.
        samplerate: Sample rate written to the file header.

    Returns:
        An ``IPython.display.Audio`` object that plays *path*.
    """
    # Objects with a .numpy() method need converting; plain arrays are written directly.
    data = audio.numpy() if hasattr(audio, "numpy") else audio
    sf.write(path, data, samplerate=samplerate)
    return IPython.display.Audio(path)
    

# From transformers import BertModel, BertTokenizer
from transformers import HfAgent, load_tool

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent

# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")

# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))

# Fetch each Hub-hosted tool (in this order) so the user can enable it in the UI.
random_character_tool, text_generation_tool = (
    load_tool(repo_id)
    for repo_id in ("Chris4K/random-character-tool", "Chris4K/text-generation-tool")
)

# Every tool offered to the user, in checkbox order.
tools = [random_character_tool, text_generation_tool]

# Define the custom HfAgent class
class CustomHfAgent(Agent):
    """Agent that generates code by POSTing prompts to a text-generation endpoint.

    Args:
        url_endpoint: URL each prompt is POSTed to.
        token: Credential for the ``Authorization`` header. A bare token is
            sent as ``"Bearer <token>"``; values that already start with
            ``"Bearer"`` or ``"Basic"`` are sent unchanged. When omitted
            (``None``), the ``HF_token`` environment variable is read at
            construction time.
        chat_prompt_template: Forwarded to ``Agent.__init__``.
        run_prompt_template: Forwarded to ``Agent.__init__``.
        additional_tools: Forwarded to ``Agent.__init__``.
        input_params: Optional generation settings; ``max_new_tokens`` is read
            from it (default 192). ``None`` is treated as an empty dict.
    """

    # How many times a rate-limited (HTTP 429) request is retried before failing.
    max_rate_limit_retries = 5

    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        if token is None:
            # Resolved here rather than in the default argument so that merely
            # importing the module does not fail when the variable is unset.
            token = os.environ["HF_token"]
        # The header value must carry an auth scheme, not just the raw token.
        if not token.startswith(("Bearer", "Basic")):
            token = f"Bearer {token}"
        self.token = token
        self.input_params = input_params if input_params is not None else {}

    def generate_one(self, prompt, stop):
        """POST *prompt* to the endpoint and return the generated continuation.

        Args:
            prompt: Full prompt text.
            stop: Stop sequences sent to the endpoint; if the result ends with
                one of them, that trailing sequence is removed.

        Returns:
            The generated text.

        Raises:
            ValueError: if the endpoint answers with a non-200 status,
                including a 429 that persists after ``max_rate_limit_retries``
                retries.
        """
        headers = {"Authorization": self.token}
        # Use the value from input_params or a default value if not provided
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop},
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        retries = 0
        # Bounded retry loop: re-sends the same request (including `stop`)
        # instead of recursing into a method that does not exist.
        while response.status_code == 429 and retries < self.max_rate_limit_retries:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            retries += 1
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; fall back to raw text
            # so the real HTTP failure is reported instead of a decode error.
            try:
                detail = response.json()
            except ValueError:
                detail = response.text
            raise ValueError(f"Errors {inputs} {response.status_code}: {detail}")
        print(response)
        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence
        for stop_seq in stop:
            # Skip empty sequences: "".endswith("") is True and would blank the result.
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result

st.title("Hugging Face Agent and tools")

# Input field for the user's message. The explicit key lets callbacks reset
# the box through st.session_state.
message = st.text_input("Enter your message:", "", key="message")

# Checkboxes for the tools to be used by the agent; each entry is the
# checkbox's current True/False state, in the same order as `tools`.
tool_checkboxes = [st.checkbox(f"Use {tool.name} --- {tool.description} ") for tool in tools]

# Tools ticked on this run of the script. This must be defined before the
# agent is constructed below, which consumes it.
selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]

# Initialize the agent with the ticked tools
agent = CustomHfAgent(
    url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
    token=os.environ['HF_token'],
    additional_tools=selected_tools,
    input_params={"max_new_tokens": 192},  # Set the desired value
)


# Define the callback function to handle the form submission
def handle_submission():
    """Send the current message to the agent and render its reply.

    Reads the module-level ``message``, ``tools``, ``tool_checkboxes`` and
    ``agent``, calls ``agent.chat`` and displays the result with the
    Streamlit element matching its type.
    """
    if not message.strip():
        # Avoid spending an API call on an empty prompt.
        st.warning("Please enter a message first. For Example: Generate an image of a boat in the water")
        return

    selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]
    print(selected_tools)
    # NOTE(review): confirm the Agent base class actually consults a `tools`
    # attribute; its toolbox may only be set through `additional_tools`.
    agent.tools = selected_tools

    response = agent.chat(message)

    # f-string rather than `+`: response may be None or a non-str object.
    print(f"Response {response}")
    _render_response(response)


def _render_response(response):
    """Display an agent response with the Streamlit element suited to its type."""
    # Type checks come before any substring test so non-str responses never
    # reach `in`, and plain text mentioning "audio" is not decoded as audio.
    if response is None:
        st.warning("The agent's response is None.  Please try again. For Example: Generate an image of a boat in the water")
    elif isinstance(response, Image.Image):
        st.image(response)
    elif isinstance(response, AudioSegment):
        st.audio(_audio_segment_to_wav(response), format="audio/wav")
    elif isinstance(response, str) and response.startswith("data:audio") and "," in response:
        # Base64 data URI: the payload is everything after the first comma.
        audio_data = base64.b64decode(response.split(",", 1)[1])
        audio = AudioSegment.from_file(io.BytesIO(audio_data))
        st.audio(_audio_segment_to_wav(audio), format="audio/wav")
    elif isinstance(response, str):
        st.write(response)
    elif isinstance(response, dict) and "text" in response:
        st.write(response["text"])
    else:
        st.warning("Unrecognized response type. Please try again. For Example: Generate an image of a boat in the water")


def _audio_segment_to_wav(segment):
    """Return *segment* encoded as WAV bytes, a form ``st.audio`` accepts."""
    buffer = io.BytesIO()
    segment.export(buffer, format="wav")
    return buffer.getvalue()



# "Submit" button: clicking it invokes handle_submission as its on_click callback.
submit_button = st.button(label="Submit", on_click=handle_submission)

# Define a callback function to handle the button click
def ask_again():
    """Clear the message box and start a fresh chat with the agent.

    A Streamlit widget's value cannot be reset through a Python variable; it
    has to be reset through ``st.session_state`` under the widget's key. This
    assumes the message ``st.text_input`` uses ``key="message"``. Without that
    key, the assignment only stores an unused session-state entry.
    """
    st.session_state["message"] = ""
    # Drop prior chat turns rather than sending an empty prompt to the model.
    # NOTE(review): prepare_for_new_chat is inherited from transformers' Agent;
    # confirm it exists in the pinned transformers version.
    agent.prepare_for_new_chat()


# Add the callback function to the button. The result is stored under its own
# name so the ask_again callback is not shadowed by the button's bool.
ask_again_button = st.button("Ask again", on_click=ask_again)