# Chris4K's picture
# Update app.py
# 713dd39 verified
# raw
# history blame
# 5.95 kB
import base64
import io
import os
import time

import IPython
import requests
import soundfile as sf
import streamlit as st
from PIL import Image
#from pydub.playback import Audio
from pydub import AudioSegment
def play_audio(audio, *, samplerate=16000, path="speech_converted.wav"):
    """Write a waveform to a WAV file and return a notebook player for it.

    Args:
        audio: Object exposing ``.numpy()``; the resulting array is handed to
            ``soundfile.write`` (presumably a torch tensor of samples --
            confirm against callers).
        samplerate: Sample rate passed to ``soundfile.write``. Defaults to
            16000, the value that used to be hard-coded.
        path: Destination file. Defaults to ``"speech_converted.wav"``, the
            file name that used to be hard-coded.

    Returns:
        An ``IPython.display.Audio`` object created from ``path``.
    """
    sf.write(path, audio.numpy(), samplerate=samplerate)
    return IPython.display.Audio(path)
# From transformers import BertModel, BertTokenizer
#from transformers import load_tool
from transformers import HfAgent, load_tool
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent
# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")
# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))
# Load the tools via ``load_tool``.
random_character_tool = load_tool("Chris4K/random-character-tool")
text_generation_tool = load_tool("Chris4K/text-generation-tool")
#sentiment_tool = load_tool("Chris4K/sentiment-tool")
token_counter_tool = load_tool("Chris4K/token-counter-tool")
most_downloaded_model = load_tool("Chris4K/most-downloaded-model")
#rag_tool = load_tool("Chris4K/rag-tool")
word_counter_tool = load_tool("Chris4K/word-counter-tool")
sentence_counter_tool = load_tool("Chris4K/sentence-counter-tool")
emojify_text_tool = load_tool("Chris4K/EmojifyTextTool")
namedEntityRecognitionTool = load_tool("Chris4K/NamedEntityRecognitionTool")
textDownloadTool = load_tool("Chris4K/TextDownloadTool")
sourcecode_retriever_tool = load_tool("Chris4K/source-code-retriever-tool")
text_to_image = load_tool("Chris4K/text-to-image")
text_to_video = load_tool("Chris4K/text-to-video")
image_transformation = load_tool("Chris4K/image-transformation")
latent_upscaler_tool = load_tool("Chris4K/latent-upscaler-tool")

# Tools offered to the user. ``sentiment_tool`` is intentionally absent: its
# load call is commented out above, so listing it raised NameError at start-up.
# ``textDownloadTool`` is loaded but, as in the original list, not offered.
tools = [
    random_character_tool,
    text_generation_tool,
    token_counter_tool,
    most_downloaded_model,
    word_counter_tool,
    sentence_counter_tool,
    emojify_text_tool,
    namedEntityRecognitionTool,
    sourcecode_retriever_tool,
    text_to_image,
    text_to_video,
    image_transformation,
    latent_upscaler_tool,
]
# Custom agent that sends a token and extra generation parameters
# (e.g. max_new_tokens) to an HTTP inference endpoint.
class CustomHfAgent(Agent):
    """Agent that generates text by POSTing prompts to ``url_endpoint``.

    Args:
        url_endpoint: URL the prompt payload is POSTed to.
        token: Value used for the ``Authorization`` header. Note that the
            default is read from ``os.environ['HF_token']`` once, when the
            class is defined.
        chat_prompt_template: Forwarded unchanged to ``Agent.__init__``.
        run_prompt_template: Forwarded unchanged to ``Agent.__init__``.
        additional_tools: Forwarded unchanged to ``Agent.__init__``.
        input_params: Optional dict of generation settings; only
            ``"max_new_tokens"`` is read (default 192 when absent).
    """

    def __init__(
        self,
        url_endpoint,
        token=os.environ['HF_token'],
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
        input_params=None,
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params

    def generate_one(self, prompt, stop):
        """Request one completion for ``prompt`` and strip a trailing stop sequence.

        Args:
            prompt: Text sent as the ``"inputs"`` field of the request.
            stop: Iterable of stop sequences; sent to the endpoint and, if the
                returned text ends with one of them, removed from its end.

        Returns:
            The ``"generated_text"`` of the first element of the JSON response,
            without a trailing stop sequence.

        Raises:
            ValueError: If the endpoint answers with a status other than 200
                (429 responses are retried instead of raised).
        """
        # The Inference API expects an auth scheme in the header; add "Bearer"
        # when the stored token is a bare token string.
        token = self.token
        if token and not token.startswith(("Bearer", "Basic")):
            token = f"Bearer {token}"
        headers = {"Authorization": token}

        # input_params defaults to None, so guard before calling .get().
        max_new_tokens = (self.input_params or {}).get("max_new_tokens", 192)
        # Set padding and truncation options
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }

        # Retry in a loop while rate-limited. The previous code recursed into a
        # non-existent ``_generate_one`` and dropped the ``stop`` argument.
        while True:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
            if response.status_code != 429:
                break
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)

        if response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        print(response)
        result = response.json()[0]["generated_text"]

        # Inference API returns the stop sequence
        for stop_seq in stop:
            # Skip empty sequences: ``result[: -0]`` would wipe the whole text.
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
st.title("Hugging Face Agent and tools")

# Build the agent that talks to the hosted StarCoder endpoint.
agent = CustomHfAgent(
    url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
    token=os.environ['HF_token'],
    additional_tools=[],
    input_params={"max_new_tokens": 192},
)

# Greet the user from the assistant's side of the chat.
with st.chat_message("assistant"):
    st.markdown("Hello there! How can I assist you today?")

# Text box holding the user's message.
user_message = st.text_input("User:", key="user_input")

# One checkbox per entry of ``tools``, in the same order, so the index of a
# checked box identifies the tool it belongs to.
tool_checkboxes = [
    st.checkbox(f"Use {candidate.name} --- {candidate.description} ")
    for candidate in tools
]

# Button that triggers processing of the message.
submit_button = st.button("Submit")
# Helpers and handler used when the user submits a message.
def _decode_audio_data_uri(text):
    """Return ``(mime_type, raw_bytes)`` for a ``data:audio...`` URI, else None."""
    header, separator, payload = text.partition(",")
    if not separator or not header.startswith("data:audio"):
        return None
    try:
        raw = base64.b64decode(payload)
    except ValueError:  # binascii.Error is a ValueError subclass
        return None
    mime_type = header[len("data:"):].split(";", 1)[0]
    return mime_type, raw


def handle_submission():
    """Send the user's message to the agent and render the reply.

    Reads the module-level ``tools``, ``tool_checkboxes``, ``agent`` and
    ``user_message``. The reply is rendered according to its type: images via
    ``st.image``, audio via ``st.audio``, text via ``st.markdown``; ``None``
    and unrecognised types produce a warning.
    """
    selected_tools = [
        tool for tool, checked in zip(tools, tool_checkboxes) if checked
    ]
    # NOTE(review): confirm that the Agent base class actually reads a
    # ``tools`` attribute; if it does not, this selection has no effect.
    agent.tools = selected_tools
    response = agent.chat(user_message)
    print("Agent Response\n {}".format(response))

    # Display the agent's response. Type checks come before any ``in`` test so
    # that non-container replies cannot raise TypeError and ordinary text that
    # merely mentions "audio" is not mistaken for encoded audio.
    with st.chat_message("assistant"):
        if response is None:
            st.warning("The agent's response is None. Please try again.")
        elif isinstance(response, Image.Image):
            st.image(response)
        elif isinstance(response, AudioSegment):
            # st.audio does not take an AudioSegment; hand it WAV bytes.
            wav_buffer = io.BytesIO()
            response.export(wav_buffer, format="wav")
            st.audio(wav_buffer.getvalue(), format="audio/wav")
        elif isinstance(response, str):
            decoded = _decode_audio_data_uri(response)
            if decoded is not None:
                mime_type, raw = decoded
                st.audio(raw, format=mime_type)
            else:
                st.markdown(response)
        elif isinstance(response, (dict, list, tuple, set, frozenset)) and "text" in response:
            # Kept from the original: containers holding "text" are rendered as-is.
            st.markdown(response)
        else:
            st.warning("Unrecognized response type. Please try again.")
# Process the message only when the Submit button reports a click.
if submit_button:
    handle_submission()