|
import streamlit as st |
|
import os |
|
import requests |
|
|
|
from PIL import Image |
|
|
|
from pydub import AudioSegment |
|
|
|
import IPython |
|
import soundfile as sf |
|
|
|
def play_audio(audio):
    """Save ``audio`` as a WAV file and return an IPython audio player for it.

    The samples are obtained via ``audio.numpy()`` and written to
    ``speech_converted.wav`` in the current working directory with a sample
    rate of 16000 Hz. The returned ``IPython.display.Audio`` object points at
    that file.
    """
    wav_path = "speech_converted.wav"
    samples = audio.numpy()
    sf.write(wav_path, samples, samplerate=16000)
    return IPython.display.Audio(wav_path)
|
|
|
|
|
|
|
|
|
from transformers import HfAgent, load_tool |
|
|
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the custom tools published on the Hugging Face Hub.
random_character_tool = load_tool("Chris4K/random-character-tool")
text_generation_tool = load_tool("Chris4K/text-generation-tool")

token_counter_tool = load_tool("Chris4K/token-counter-tool")
most_downloaded_model = load_tool("Chris4K/most-downloaded-model")

word_counter_tool = load_tool("Chris4K/word-counter-tool")
sentence_counter_tool = load_tool("Chris4K/sentence-counter-tool")
emojify_text_tool = load_tool("Chris4K/EmojifyTextTool")
namedEntityRecognitionTool = load_tool("Chris4K/NamedEntityRecognitionTool")
textDownloadTool = load_tool("Chris4K/TextDownloadTool")
sourcecode_retriever_tool = load_tool("Chris4K/source-code-retriever-tool")

text_to_image = load_tool("Chris4K/text-to-image")
text_to_video = load_tool("Chris4K/text-to-video")
image_transformation = load_tool("Chris4K/image-transformation")
latent_upscaler_tool = load_tool("Chris4K/latent-upscaler-tool")

# Tools offered as checkboxes in the UI.
# NOTE(review): the list previously also named `sentiment_tool`, which is never
# loaded in this file and therefore raised NameError at import. Re-add it here
# once a matching load_tool(...) call exists.
# NOTE(review): `textDownloadTool` is loaded above but not offered here --
# confirm whether that omission is intentional.
tools = [
    random_character_tool,
    text_generation_tool,
    token_counter_tool,
    most_downloaded_model,
    word_counter_tool,
    sentence_counter_tool,
    emojify_text_tool,
    namedEntityRecognitionTool,
    sourcecode_retriever_tool,
    text_to_image,
    text_to_video,
    image_transformation,
    latent_upscaler_tool,
]
|
|
|
|
|
class CustomHfAgent(Agent):
    """Agent whose text generation is delegated to an HTTP inference endpoint.

    Each generation POSTs a JSON payload (``inputs`` + ``parameters``) to
    ``url_endpoint`` and reads ``generated_text`` from the first element of the
    JSON response.

    Args:
        url_endpoint: URL the generation requests are posted to.
        token: Value for the ``Authorization`` header. When omitted, the
            ``HF_token`` environment variable is read at construction time.
            A token without a ``Bearer``/``Basic`` scheme is sent as
            ``Bearer <token>``.
        chat_prompt_template: Forwarded unchanged to ``Agent.__init__``.
        run_prompt_template: Forwarded unchanged to ``Agent.__init__``.
        additional_tools: Forwarded unchanged to ``Agent.__init__``.
        input_params: Optional dict of generation settings; only
            ``max_new_tokens`` is read (default 192).

    Raises:
        KeyError: If ``token`` is omitted and ``HF_token`` is not set.
    """

    def __init__(
        self,
        url_endpoint,
        token=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
        input_params=None,
    ):
        # Resolve the token when the agent is built rather than when the class
        # is defined, so importing this module does not require HF_token and a
        # changed environment value is picked up.
        if token is None:
            token = os.environ['HF_token']

        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params

    def generate_one(self, prompt, stop):
        """Generate one completion for ``prompt`` via the remote endpoint.

        Args:
            prompt: Text sent as the ``inputs`` field.
            stop: Iterable of stop sequences; sent to the endpoint and, if the
                result ends with one of them, stripped from the returned text.

        Returns:
            The generated text with a trailing stop sequence removed.

        Raises:
            ValueError: If the endpoint answers with a status other than 200
                or 429.
        """
        # Local import: `time` is not imported at the top of this file.
        import time

        # Send a bare token as "Bearer <token>"; leave an explicit scheme alone.
        auth_value = self.token
        if auth_value and not auth_value.startswith(("Bearer", "Basic")):
            auth_value = f"Bearer {auth_value}"
        headers = {"Authorization": auth_value}

        # input_params defaults to None; treat that as "no overrides".
        input_params = self.input_params or {}
        max_new_tokens = input_params.get("max_new_tokens", 192)

        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }

        # Retry in a loop (not by recursion) for as long as we are rate-limited.
        while True:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
            if response.status_code != 429:
                break
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)

        if response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        print(response)
        result = response.json()[0]["generated_text"]

        for stop_seq in stop:
            # Skip empty stop sequences: result[:-0] would discard everything.
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
|
|
|
# --- Page header -----------------------------------------------------------
st.title("Hugging Face Agent and tools")

# --- Agent backed by the hosted StarCoder inference endpoint ---------------
agent = CustomHfAgent(
    url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
    token=os.environ['HF_token'],
    additional_tools=[],
    input_params={"max_new_tokens": 192},
)

# --- Greeting shown in the assistant chat bubble ---------------------------
with st.chat_message("assistant"):
    st.markdown("Hello there! How can I assist you today?")

# --- User input widgets ----------------------------------------------------
user_message = st.text_input("User:", key="user_input")

# One checkbox per tool; the resulting list of booleans is index-aligned
# with `tools`.
tool_checkboxes = [
    st.checkbox(f"Use {tool.name} --- {tool.description} ")
    for tool in tools
]

submit_button = st.button("Submit")
|
|
|
|
|
def _audio_segment_to_wav_bytes(segment):
    """Serialize a pydub ``AudioSegment`` to WAV bytes that ``st.audio`` accepts."""
    import io

    buffer = io.BytesIO()
    segment.export(buffer, format="wav")
    return buffer.getvalue()


def _decode_audio_string(response):
    """Decode a ``"<header>,<base64>"`` string into an ``AudioSegment``.

    Returns ``None`` when the string has no comma-separated payload or the
    payload is not valid base64, so the caller can show the string as text.
    """
    import base64
    import binascii
    import io

    parts = response.split(",")
    if len(parts) < 2:
        return None
    try:
        audio_data = base64.b64decode(parts[1].strip(), validate=True)
    except (binascii.Error, ValueError):
        return None
    if not audio_data:
        return None
    return AudioSegment.from_file(io.BytesIO(audio_data))


def _contains(container, item):
    """Return ``item in container``, or ``False`` if ``container`` is not iterable."""
    try:
        return item in container
    except TypeError:
        return False


def handle_submission():
    """Send the user's message to the agent and render the reply.

    Reads the module-level ``tools``, ``tool_checkboxes``, ``agent`` and
    ``user_message``. The reply is rendered inside an assistant chat message
    according to its type: image, audio (an ``AudioSegment`` or a string
    carrying a base64 payload), text, or a warning when the reply is ``None``
    or of an unrecognized type.
    """
    selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]

    # NOTE(review): this only sets a `tools` attribute on the agent; confirm
    # that the Agent base class actually reads it when choosing tools.
    agent.tools = selected_tools

    response = agent.chat(user_message)

    print("Agent Response\n {}".format(response))

    with st.chat_message("assistant"):
        # Type checks come first so that a substring test is never applied to
        # an image/audio object and plain text is never mistaken for audio.
        if response is None:
            st.warning("The agent's response is None. Please try again.")
        elif isinstance(response, Image.Image):
            st.image(response)
        elif isinstance(response, AudioSegment):
            # st.audio does not take an AudioSegment; hand it WAV bytes.
            st.audio(_audio_segment_to_wav_bytes(response), format="audio/wav")
        elif isinstance(response, str):
            audio = _decode_audio_string(response) if "audio" in response else None
            if audio is not None:
                st.audio(_audio_segment_to_wav_bytes(audio), format="audio/wav")
            else:
                st.markdown(response)
        elif _contains(response, "text"):
            st.markdown(response)
        else:
            st.warning("Unrecognized response type. Please try again.")
|
|
|
|
|
# Query the agent only when the "Submit" button was clicked.
if submit_button:
    handle_submission()