File size: 5,950 Bytes
fdfb0c4 45f17fe d6555d8 d7ca359 d2b3f71 6cc2332 fb5ba89 d7ca359 6bab521 e14d608 b066a4d 3709e0d d6555d8 3709e0d 6bab521 3709e0d 6bab521 4dc413a 6bab521 4dc413a b066a4d bd4ac0b 97b1f2a e8494e9 a20648e e8494e9 713dd39 e8494e9 7155419 7755f96 fb5ba89 7755f96 6bab521 fb5ba89 7755f96 fb5ba89 c218a80 7755f96 c218a80 7755f96 c218a80 7755f96 c218a80 7755f96 e58825b 7755f96 fb5ba89 b066a4d bd4ac0b 5274330 377076c bd4ac0b 377076c 7155419 bd4ac0b 7155419 b066a4d 7155419 b066a4d 7155419 b066a4d 7155419 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import base64
import io
import os
import time

import IPython
import requests
import soundfile as sf
import streamlit as st
from PIL import Image
#from pydub.playback import Audio
from pydub import AudioSegment
def play_audio(audio, samplerate=16000, path="speech_converted.wav"):
    """Write *audio* to a WAV file and return a notebook audio player for it.

    Args:
        audio: Waveform samples. Objects exposing ``.numpy()`` are converted
            through it; anything else is handed to ``soundfile`` unchanged.
        samplerate: Sample rate in Hz written to the file header.
        path: Destination file; an existing file is overwritten.

    Returns:
        An ``IPython.display.Audio`` object pointing at *path*.
    """
    # Objects with a .numpy() method are converted as before; plain arrays
    # or lists are passed through so callers need not wrap them first.
    samples = audio.numpy() if hasattr(audio, "numpy") else audio
    sf.write(path, samples, samplerate=samplerate)
    return IPython.display.Audio(path)
# From transformers import BertModel, BertTokenizer
#from transformers import load_tool
from transformers import HfAgent, load_tool
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent
# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")
# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))
# Load tools
random_character_tool = load_tool("Chris4K/random-character-tool")
text_generation_tool = load_tool("Chris4K/text-generation-tool")
#sentiment_tool = load_tool("Chris4K/sentiment-tool")
token_counter_tool = load_tool("Chris4K/token-counter-tool")
most_downloaded_model = load_tool("Chris4K/most-downloaded-model")
#rag_tool = load_tool("Chris4K/rag-tool")
word_counter_tool = load_tool("Chris4K/word-counter-tool")
sentence_counter_tool = load_tool("Chris4K/sentence-counter-tool")
emojify_text_tool = load_tool("Chris4K/EmojifyTextTool")
namedEntityRecognitionTool = load_tool("Chris4K/NamedEntityRecognitionTool")
textDownloadTool = load_tool("Chris4K/TextDownloadTool")
sourcecode_retriever_tool = load_tool("Chris4K/source-code-retriever-tool")
text_to_image = load_tool("Chris4K/text-to-image")
text_to_video = load_tool("Chris4K/text-to-video")
image_transformation = load_tool("Chris4K/image-transformation")
latent_upscaler_tool = load_tool("Chris4K/latent-upscaler-tool")

# Tools offered to the user. Every entry must have been loaded above: the
# sentiment tool's load_tool call is commented out, so listing it here raised
# NameError before the app could start.
# NOTE(review): textDownloadTool is loaded but was never part of this list;
# left out to preserve the existing selection - confirm whether that is intended.
tools = [
    random_character_tool,
    text_generation_tool,
    token_counter_tool,
    most_downloaded_model,
    word_counter_tool,
    sentence_counter_tool,
    emojify_text_tool,
    namedEntityRecognitionTool,
    sourcecode_retriever_tool,
    text_to_image,
    text_to_video,
    image_transformation,
    latent_upscaler_tool,
]
# Custom agent that calls a Hugging Face inference endpoint with an access token and caller-supplied generation parameters (input_params, e.g. max_new_tokens)
class CustomHfAgent(Agent):
    """Agent that generates text by POSTing prompts to an inference endpoint.

    Args:
        url_endpoint: URL the prompt is POSTed to.
        token: Access token. When omitted, the ``HF_token`` environment
            variable is read at construction time. A bare token is prefixed
            with ``Bearer``; values already starting with ``Bearer`` or
            ``Basic`` are used as given. With no token at all, requests are
            sent without an Authorization header.
        chat_prompt_template: Forwarded to ``Agent.__init__``.
        run_prompt_template: Forwarded to ``Agent.__init__``.
        additional_tools: Forwarded to ``Agent.__init__``.
        input_params: Optional mapping of generation settings; only
            ``max_new_tokens`` is read (default 192).
    """

    # Number of extra attempts made after an HTTP 429 before giving up, so a
    # persistently rate-limited endpoint cannot hang the caller forever.
    max_rate_limit_retries = 5

    def __init__(
        self,
        url_endpoint,
        token=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
        input_params=None,
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Resolve the environment fallback here rather than in the signature:
        # a default of os.environ['HF_token'] is evaluated once, when the
        # class is defined, and raises KeyError if the variable is unset.
        if token is None:
            token = os.environ.get("HF_token")
        if token and not token.startswith(("Bearer", "Basic")):
            token = f"Bearer {token}"
        self.token = token
        # Keep the caller's mapping when given; fall back to an empty dict so
        # generate_one can always call .get() on it.
        self.input_params = input_params if input_params is not None else {}

    def generate_one(self, prompt, stop):
        """Request one completion for *prompt* and strip a trailing stop sequence.

        Args:
            prompt: Text sent as the ``inputs`` field of the request.
            stop: Iterable of stop sequences, also forwarded to the endpoint.

        Returns:
            The ``generated_text`` of the first result, without the stop
            sequence it ends with (if any).

        Raises:
            ValueError: If the endpoint answers with a non-200 status,
                including a 429 that persists through all retries.
        """
        headers = {"Authorization": self.token} if self.token else {}
        # Use the value from input_params or a default value if not provided
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        # Set padding and truncation options
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }

        retries_left = self.max_rate_limit_retries
        while True:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
            if response.status_code != 429 or retries_left <= 0:
                break
            retries_left -= 1
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)

        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; fall back to raw text
            # so the status code is still reported.
            try:
                detail = response.json()
            except ValueError:
                detail = response.text
            raise ValueError(f"Errors {inputs} {response.status_code}: {detail}")

        print(response)
        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence
        for stop_seq in stop or ():
            # Skip empty sequences: result[:-0] would return an empty string.
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
# Page heading.
st.title("Hugging Face Agent and tools")

# Agent backed by the hosted StarCoder inference endpoint; no extra tools are
# registered at construction time.
agent = CustomHfAgent(
    "https://api-inference.huggingface.co/models/bigcode/starcoder",
    token=os.environ["HF_token"],
    additional_tools=list(),
    input_params=dict(max_new_tokens=192),
)

# Opening assistant message.
st.chat_message("assistant").markdown("Hello there! How can I assist you today?")

# Free-text prompt typed by the user.
user_message = st.text_input(label="User:", key="user_input")

# One checkbox per available tool, kept in the same order as `tools` so the
# two lists can be matched up by position later.
tool_checkboxes = []
for tool in tools:
    tool_checkboxes.append(st.checkbox(f"Use {tool.name} --- {tool.description} "))

# Button that triggers a request to the agent.
submit_button = st.button(label="Submit")
# Define the callback function to handle the form submission
def _decode_audio_payload(text):
    """Return the bytes of a base64 audio payload in *text*, or None if there is none."""
    header, separator, payload = text.partition(",")
    if not separator or "audio" not in header:
        return None
    try:
        # validate=True rejects ordinary sentences that merely mention "audio".
        return base64.b64decode(payload.strip(), validate=True) or None
    except ValueError:
        return None


def _audio_segment_to_wav_bytes(segment):
    """Export a pydub AudioSegment to in-memory WAV bytes."""
    buffer = io.BytesIO()
    segment.export(buffer, format="wav")
    return buffer.getvalue()


def handle_submission():
    """Send the user's message to the agent and render whatever it returns.

    Reads the module-level ``tools``, ``tool_checkboxes``, ``agent`` and
    ``user_message``. The response is rendered as an image, audio player or
    markdown depending on its type; ``None`` or an unrecognized type produces
    a warning instead.
    """
    selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]
    agent.tools = selected_tools
    # NOTE(review): transformers' Agent resolves tools through its `toolbox`
    # mapping, so the selection is registered there as well; the `tools`
    # attribute above is kept for compatibility. Confirm against the
    # installed transformers version.
    for tool in selected_tools:
        agent.toolbox[tool.name] = tool

    response = agent.chat(user_message)
    print("Agent Response\n {}".format(response))

    # Display the agent's response
    with st.chat_message("assistant"):
        if response is None:
            st.warning("The agent's response is None. Please try again.")
        elif isinstance(response, Image.Image):
            st.image(response)
        elif isinstance(response, AudioSegment):
            # st.audio does not take an AudioSegment; hand it WAV bytes.
            st.audio(_audio_segment_to_wav_bytes(response), format="audio/wav")
        elif isinstance(response, str):
            # Only a string that really carries a base64 audio payload is
            # played; any other text (even if it mentions "audio") is shown.
            audio_data = _decode_audio_payload(response)
            if audio_data is not None:
                st.audio(audio_data)
            else:
                st.markdown(response)
        elif isinstance(response, dict) and "text" in response:
            st.markdown(str(response["text"]))
        else:
            st.warning("Unrecognized response type. Please try again.")
# Add the callback function to the Streamlit app
# Query the agent only when submit_button is truthy.
if submit_button:
    handle_submission()