# NOTE: non-Python residue (a file-size banner, git-blame commit hashes, and a
# line-number gutter) was pasted above the imports; replaced with this comment
# so the module parses.
import base64
import io
import os
import time

import altair as alt  # For Altair charts
import IPython
import matplotlib.figure  # If you're using matplotlib figures
import numpy as np
import pandas as pd  # If you're working with DataFrames
import plotly.express as px  # For Plotly charts
import plotly.graph_objects as go  # Plotly Figure type for isinstance checks
import pydeck as pdk  # For Pydeck charts
import requests
import soundfile as sf
import streamlit as st
import torch
from bokeh.models import Plot  # For Bokeh charts
from PIL import Image
from pydub import AudioSegment
from transformers import load_tool, Agent
class ToolLoader:
    """Eagerly loads a set of Transformers tools by name.

    Tools that fail to load are logged and skipped, so one broken tool
    does not take down the whole app.
    """

    def __init__(self, tool_names):
        # Resolve every requested tool up front; failures are dropped.
        self.tools = self.load_tools(tool_names)

    def load_tools(self, tool_names):
        """Return the subset of `tool_names` that loaded successfully."""
        available = []
        for name in tool_names:
            try:
                available.append(load_tool(name))
            except Exception as error:
                # Best-effort: report the failure and move on to the next tool.
                print(f"Error loading tool '{name}': {error}")
        return available
class CustomHfAgent(Agent):
    """Transformers Agent backed by a remote Hugging Face Inference endpoint.

    Prompts are POSTed to ``url_endpoint`` with a bearer-style token; the
    generated text is returned with any trailing stop sequence stripped.
    """

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        # BUG FIX: default to an empty dict so generate_one's
        # self.input_params.get(...) does not raise AttributeError when the
        # caller omits input_params (the old code stored None as-is).
        self.input_params = input_params if input_params is not None else {}

    def generate_one(self, prompt, stop):
        """Send `prompt` to the endpoint and return the generated text.

        `stop` is a list of stop sequences; if the generated text ends with
        one of them, that suffix is removed. Raises ValueError on any HTTP
        status other than 200/429; a 429 is retried after a short sleep.
        """
        # NOTE(review): many HF inference endpoints expect "Bearer <token>"
        # here — confirm the Authorization scheme this endpoint requires.
        headers = {"Authorization": self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # BUG FIX: the original retried via self._generate_one(prompt) —
            # a method that does not exist on this class — and also dropped
            # the `stop` argument. Retry through this method instead.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        print(response)
        result = response.json()[0]["generated_text"]
        # Trim a trailing stop sequence so callers see clean text.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
def load_tools(tool_names):
    """Load every named tool; a single failure propagates to the caller."""
    loaded = []
    for name in tool_names:
        loaded.append(load_tool(name))
    return loaded
# Define the tool names to load
# Hub tool IDs exposed to the user as checkboxes below.
tool_names = [
    "Chris4K/random-character-tool",
    "Chris4K/text-generation-tool",
    "Chris4K/sentiment-tool",
    "Chris4K/EmojifyTextTool",
    # Add other tool names as needed
]
# Create tool loader instance
# Runs at import time; tools that fail to load are logged and skipped.
tool_loader = ToolLoader(tool_names)
# Define the callback function to handle the form submission
def handle_submission(user_message, selected_tools, url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder"):
    """Run the agent on `user_message` with the user-selected tools.

    Parameters
    ----------
    user_message : str
        Raw chat input from the user.
    selected_tools : list
        Tool objects whose checkboxes were ticked.
    url_endpoint : str, optional
        Inference endpoint to query. Generalized from the previously
        hard-coded constant; the default preserves the original behavior.

    Returns the agent's response (its type depends on which tool ran).
    Raises KeyError if the HF_token environment variable is unset.
    """
    agent = CustomHfAgent(
        url_endpoint=url_endpoint,
        token=os.environ['HF_token'],
        additional_tools=selected_tools,
        input_params={"max_new_tokens": 192},
    )
    response = agent.run(user_message)
    print("Agent Response\n {}".format(response))
    return response
# ----------------------------- Streamlit UI -----------------------------
st.title("Hugging Face Agent and tools")

# Chat history survives Streamlit reruns via session state.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# One checkbox per loaded tool; index positions line up with tool_loader.tools.
tool_checkboxes = [st.checkbox(f"{tool.name} --- {tool.description} ") for tool in tool_loader.tools]

with st.chat_message("assistant"):
    st.markdown("Hello there! How can I assist you today?")

if user_message := st.chat_input("Enter message"):
    st.chat_message("user").markdown(user_message)
    st.session_state.messages.append({"role": "user", "content": user_message})

    # Map ticked checkboxes back to their tool objects.
    selected_tools = [tool_loader.tools[idx] for idx, checkbox in enumerate(tool_checkboxes) if checkbox]

    response = handle_submission(user_message, selected_tools)

    # Render the response with the widget that matches its runtime type.
    with st.chat_message("assistant"):
        if response is None:
            st.warning("The agent's response is None. Please try again. Generate an image of a flying horse.")
        elif isinstance(response, Image.Image):
            st.image(response)
        elif isinstance(response, AudioSegment):
            # BUG FIX: st.audio cannot render a pydub AudioSegment directly;
            # export it to WAV bytes first.
            wav_buffer = io.BytesIO()
            response.export(wav_buffer, format="wav")
            st.audio(wav_buffer.getvalue(), format="audio/wav")
        elif isinstance(response, int):
            st.markdown(response)
        elif isinstance(response, str):
            # BUG FIX: the original indexed the string as
            # response['emojified_text'], which raises TypeError for str.
            # A plain string is rendered as markdown.
            st.markdown(response)
        elif isinstance(response, list):
            for item in response:
                st.markdown(item)  # assumes the list holds markdown-renderable items
        elif isinstance(response, pd.DataFrame):
            st.dataframe(response)
        elif isinstance(response, pd.Series):
            st.table(response.iloc[0:10])
        elif isinstance(response, dict):
            st.json(response)
        # BUG FIX: the original tested against streamlit.graphics_* classes
        # that do not exist (and the bare name `streamlit` was never bound —
        # the module is imported as `st`), so these branches raised NameError.
        # Check against the chart libraries' own figure types instead.
        elif isinstance(response, alt.Chart):
            st.altair_chart(response)
        elif isinstance(response, Plot):
            st.bokeh_chart(response)
        elif isinstance(response, go.Figure):
            st.plotly_chart(response)
        elif isinstance(response, pdk.Deck):
            st.pydeck_chart(response)
        elif isinstance(response, matplotlib.figure.Figure):
            st.pyplot(response)
        else:
            st.warning("Unrecognized response type. Please try again. e.g. Generate an image of a flying horse.")

    st.session_state.messages.append({"role": "assistant", "content": response})