File size: 6,246 Bytes
fdfb0c4
45f17fe
120d4a1
 
d7ca359
6cc2332
fb5ba89
 
6052994
 
 
918fcdb
6052994
 
918fcdb
 
6052994
b0fad1c
6052994
 
e55eeac
6052994
 
e55eeac
6052994
 
 
120d4a1
 
fb5ba89
6052994
120d4a1
 
6052994
 
 
 
 
 
 
 
 
 
 
 
 
b066a4d
7755f96
120d4a1
7755f96
 
 
 
 
6bab521
 
fb5ba89
7755f96
 
 
fb5ba89
c218a80
7755f96
 
c218a80
7755f96
 
c218a80
7755f96
 
 
 
 
 
e58825b
7755f96
 
 
 
 
 
120d4a1
 
 
 
 
 
 
6052994
 
 
120d4a1
 
 
 
 
 
7155419
120d4a1
11e5281
120d4a1
11e5281
 
ef549d1
11e5281
 
 
b066a4d
7155419
cddf298
120d4a1
cddf298
 
 
 
 
 
 
 
 
02159ac
120d4a1
 
cddf298
 
 
3892b72
cc5fac8
 
cddf298
120d4a1
 
 
7155419
 
6052994
7155419
 
 
 
47cf4e1
 
6052994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7155419
6052994
 
cddf298
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import base64
import io
import os
import time

import IPython
import matplotlib.figure  # If you're using matplotlib figures
import numpy as np
import pandas as pd  # If you're working with DataFrames
import requests
import soundfile as sf
import streamlit as st
import torch
from PIL import Image
from pydub import AudioSegment
from transformers import load_tool, Agent

# For Altair charts
import altair as alt

# For Bokeh charts
from bokeh.models import Plot

# For Plotly charts
import plotly.express as px
import plotly.graph_objects as go

# For Pydeck charts
import pydeck as pdk


class ToolLoader:
    """Load a list of Hub tools by name, skipping any that fail to load.

    Attributes:
        tools: the tool objects that ``load_tool`` returned successfully,
            in the same order as the requested names.
    """

    def __init__(self, tool_names):
        self.tools = self.load_tools(tool_names)

    def load_tools(self, tool_names):
        """Return the successfully loaded tools for ``tool_names``.

        A name whose ``load_tool`` call raises is reported on stdout and
        skipped, so one broken tool does not prevent the others from loading.
        """
        tools = []
        for name in tool_names:
            try:
                loaded = load_tool(name)
            except Exception as err:
                print(f"Error loading tool '{name}': {err}")
                continue
            tools.append(loaded)
        return tools

class CustomHfAgent(Agent):
    """``transformers`` Agent whose code generation is done by an HTTP endpoint.

    Args:
        url_endpoint: URL that prompts are POSTed to (a text-generation
            endpoint returning ``[{"generated_text": ...}]``).
        token: access token for the endpoint. A bare token is sent as
            ``Bearer <token>``; a value already starting with ``Bearer`` or
            ``Basic`` is sent unchanged.
        chat_prompt_template, run_prompt_template, additional_tools:
            forwarded unchanged to ``Agent.__init__``.
        input_params: optional dict of generation settings; only
            ``max_new_tokens`` (default 192) is read. ``None`` means defaults.
    """

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params

    def generate_one(self, prompt, stop):
        """POST ``prompt`` to the endpoint and return the generated text.

        HTTP 429 (rate limited) responses are retried after a one-second
        pause until a different status is returned.

        Args:
            prompt: full prompt text built by the Agent.
            stop: sequence of stop strings forwarded to the endpoint.

        Returns:
            The generated text, with one trailing stop sequence removed if
            present.

        Raises:
            ValueError: when the endpoint answers with a status other than
                200 (after any 429 retries).
        """
        token = self.token
        # The Inference API expects "Authorization: Bearer <token>"; the app
        # passes the raw token from the environment, so add the scheme here.
        if isinstance(token, str) and not token.startswith(("Bearer", "Basic")):
            token = f"Bearer {token}"
        headers = {"Authorization": token}

        max_new_tokens = (self.input_params or {}).get("max_new_tokens", 192)
        parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }

        # Loop instead of recursing so persistent rate limiting cannot
        # exhaust the call stack.
        while True:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
            if response.status_code != 429:
                break
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)

        if response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]
        for stop_seq in stop:
            # Skip empty stop strings: result[:-0] would wipe the whole text.
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result

def load_tools(tool_names):
    """Load every tool in ``tool_names`` with ``load_tool``, in order.

    Unlike ``ToolLoader.load_tools``, any loading error propagates to the
    caller.
    """
    tools = []
    for name in tool_names:
        tools.append(load_tool(name))
    return tools

# Hub repository ids of the custom tools offered as checkboxes in the UI.
tool_names = [
    "Chris4K/random-character-tool",
    "Chris4K/text-generation-tool",
    "Chris4K/sentiment-tool",
    "Chris4K/EmojifyTextTool",
    # Add other tool names as needed
]


@st.cache_resource
def _load_tool_loader(names):
    """Build the ToolLoader once per process instead of on every rerun.

    Streamlit re-executes this script on each user interaction; without
    caching every rerun would fetch and instantiate all tools again.
    NOTE(review): the cached tools are shared across sessions — assumes
    they hold no per-user state; confirm for new tools.
    """
    return ToolLoader(names)


# Create tool loader instance (failed tools are reported and skipped).
tool_loader = _load_tool_loader(tool_names)

def handle_submission(user_message, selected_tools):
    """Answer ``user_message`` with a freshly built StarCoder-backed agent.

    Args:
        user_message: the prompt typed into the chat input.
        selected_tools: tool objects ticked by the user; handed to the agent
            as ``additional_tools``.

    Returns:
        The raw result of ``agent.run`` (the caller decides how to render it).

    Raises:
        KeyError: if the ``HF_token`` environment variable is not set.
    """
    hf_token = os.environ['HF_token']
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=hf_token,
        additional_tools=selected_tools,
        input_params={"max_new_tokens": 192},
    )

    result = agent.run(user_message)
    print(f"Agent Response\n {result}")
    return result

def _render_response(response):
    """Display one chat message body with the Streamlit element matching its type.

    Used for both the replayed history and the live reply, so non-text
    results (images, audio, tables, charts) are rendered correctly on every
    rerun instead of being handed to ``st.markdown``.
    """
    if response is None:
        st.warning("The agent's response is None. Please try again. Generate an image of a flying horse.")
    elif isinstance(response, Image.Image):
        st.image(response)
    elif isinstance(response, AudioSegment):
        # st.audio cannot consume an AudioSegment; encode it to WAV bytes.
        buffer = io.BytesIO()
        response.export(buffer, format="wav")
        st.audio(buffer.getvalue(), format="audio/wav")
    elif isinstance(response, (int, float, np.number)):
        st.markdown(str(response))
    elif isinstance(response, str):
        st.markdown(response)
    elif isinstance(response, list):
        for item in response:
            st.markdown(str(item))  # Assuming the list contains strings
    elif isinstance(response, pd.DataFrame):
        st.dataframe(response)
    elif isinstance(response, pd.Series):
        st.table(response.iloc[0:10])
    elif isinstance(response, dict):
        # The emojify tool's result is read as a dict with this key; other
        # dicts (including Vega-Lite specs) are shown as JSON.
        if "emojified_text" in response:
            st.markdown(f"{response['emojified_text']}")
        else:
            st.json(response)
    elif isinstance(response, (alt.Chart, alt.LayerChart, alt.HConcatChart, alt.VConcatChart)):
        st.altair_chart(response)
    elif isinstance(response, Plot):
        st.bokeh_chart(response)
    elif type(response).__module__.split(".")[0] == "graphviz":
        # Checked by module name so graphviz need not be imported here.
        st.graphviz_chart(response)
    elif isinstance(response, go.Figure):
        st.plotly_chart(response)
    elif isinstance(response, pdk.Deck):
        st.pydeck_chart(response)
    elif isinstance(response, matplotlib.figure.Figure):
        st.pyplot(response)
    else:
        st.warning("Unrecognized response type. Please try again. e.g. Generate an image of a flying horse.")


st.title("Hugging Face Agent and tools")

# Greet first so the message stays at the top instead of following the
# replayed conversation.
with st.chat_message("assistant"):
    st.markdown("Hello there! How can I assist you today?")

if "messages" not in st.session_state:
    st.session_state.messages = []

# Streamlit reruns the script on every interaction; replay stored messages.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        _render_response(message["content"])

tool_checkboxes = [st.checkbox(f"{tool.name} --- {tool.description} ") for tool in tool_loader.tools]

if user_message := st.chat_input("Enter message"):
    st.chat_message("user").markdown(user_message)
    st.session_state.messages.append({"role": "user", "content": user_message})

    # Keep the tools whose checkbox is ticked (checkboxes follow tool order).
    selected_tools = [tool for tool, checked in zip(tool_loader.tools, tool_checkboxes) if checked]
    response = handle_submission(user_message, selected_tools)

    with st.chat_message("assistant"):
        _render_response(response)

    st.session_state.messages.append({"role": "assistant", "content": response})