Spaces:
Sleeping
Sleeping
File size: 5,108 Bytes
e0a73da 8fdc6bf e0a73da fa99d8f e0a73da fa99d8f a0df48e e0a73da a0df48e e0a73da 8fdc6bf fa99d8f e0a73da fa99d8f 02c0ba2 e0a73da 901959b fa99d8f e0a73da fa99d8f a0df48e e0a73da fa99d8f e0a73da fa99d8f 057d3c8 e0a73da fa99d8f e0a73da fa99d8f a0df48e 9237552 bec8a7b e0a73da fa99d8f e0a73da bec8a7b e0a73da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# HF libraries
from langchain_huggingface import HuggingFaceEndpoint
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
# Import things that are needed generically
from langchain.tools.render import render_text_description
from app.schemas.message_schema import (
IChatResponse,
)
from app.utils.utils import generate_uuid
import os
from dotenv import load_dotenv
from app.utils.adaptive_cards.cards import create_adaptive_card
from app.structured_tools.structured_tools import (
arxiv_search, get_arxiv_paper, google_search, wikipedia_search, knowledgeBase_search, memory_search
)
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from langchain.prompts import PromptTemplate
from app.templates.react_json_with_memory import template_system
from app.utils import logger
from app.utils import utils
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from app.utils.callback import (
CustomAsyncCallbackHandler,
CustomFinalStreamingStdOutCallbackHandler,
)
from langchain.memory import ConversationBufferMemory
from app.core.config import settings
# Path for the (currently disabled) SQLite LLM-response cache.
local_cache=settings.LOCAL_CACHE
#set_llm_cache(SQLiteCache(database_path=local_cache))
# NOTE: rebinds the imported `logger` module name to a concrete logger instance.
logger = logger.get_console_logger("hf_mixtral_agent")
# load_dotenv returns True if a .env file was found and loaded into os.environ.
config = load_dotenv(".env")
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
GOOGLE_CSE_ID = os.getenv('GOOGLE_CSE_ID')
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
# NOTE(review): these are plain module-level variables, not os.environ entries;
# LangSmith reads LANGCHAIN_* from the process environment, so tracing is
# probably NOT actually enabled by these two assignments — confirm intent.
LANGCHAIN_TRACING_V2 = "true"
LANGCHAIN_ENDPOINT = "https://api.smith.langchain.com"
LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')
LANGCHAIN_PROJECT = os.getenv('LANGCHAIN_PROJECT')
# GOOGLE_CSE_ID=settings.GOOGLE_CSE_ID
# GOOGLE_API_KEY=settings.GOOGLE_API_KEY
# HUGGINGFACEHUB_API_TOKEN=settings.HUGGINGFACEHUB_API_TOKEN
# print(HUGGINGFACEHUB_API_TOKEN)
router = APIRouter()
# Shared conversation memory; currently unused (the `memory=` argument to
# AgentExecutor below is commented out).
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
def _build_agent_executor() -> AgentExecutor:
    """Construct the ReAct-JSON Mixtral agent and wrap it in an AgentExecutor.

    None of the objects built here depend on an individual user message, so
    the executor can be created once per websocket connection and reused for
    every message on that connection.

    Returns:
        AgentExecutor: executor with tool access, a 10-iteration cap, and
        parsing-error recovery enabled.
    """
    # Load the model from the Hugging Face Hub.
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        temperature=0.1,
        max_new_tokens=1024,
        repetition_penalty=1.2,
        return_full_text=False,
    )
    tools = [
        memory_search,
        knowledgeBase_search,
        arxiv_search,
        wikipedia_search,
        google_search,
        # get_arxiv_paper,
    ]
    # Render tool descriptions/names into the system prompt template.
    prompt = PromptTemplate.from_template(template=template_system).partial(
        tools=render_text_description(tools),
        tool_names=", ".join([t.name for t in tools]),
    )
    # Stop generation before the model starts writing its own "Observation",
    # so tool output always comes from the real tool call.
    chat_model_with_stop = llm.bind(stop=["\nObservation"])
    # ReAct pipeline: map inputs -> prompt -> LLM -> JSON action parser.
    agent = (
        {
            "input": lambda x: x["input"],
            "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
            "chat_history": lambda x: x["chat_history"],
        }
        | prompt
        | chat_model_with_stop
        | ReActJsonSingleInputOutputParser()
    )
    return AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        max_iterations=10,  # cap number of iterations
        # max_execution_time=60,  # timeout at 60 sec
        return_intermediate_steps=True,
        handle_parsing_errors=True,
        # memory=memory
    )


@router.websocket("/agent")
async def websocket_endpoint(websocket: WebSocket):
    """Serve the Mixtral ReAct agent over a websocket.

    For each JSON message received (expects a "message" key), echoes the user
    message back as an adaptive card, then runs the agent with a streaming
    callback handler that pushes tokens to the client. Loops until the client
    disconnects.
    """
    await websocket.accept()
    if not settings.HUGGINGFACEHUB_API_TOKEN.startswith("hf_"):
        await websocket.send_json({"error": "HUGGINGFACEHUB_API_TOKEN is not set"})
        return
    # Build the agent once per connection: the original rebuilt the endpoint,
    # prompt, and executor on every single message, which is pure overhead.
    agent_executor = _build_agent_executor()
    while True:
        try:
            data = await websocket.receive_json()
            user_message = data["message"]
            user_message_card = create_adaptive_card(user_message)
            chat_history = []  # data["history"]
            # Echo the user's message back to the client as an adaptive card.
            resp = IChatResponse(
                sender="you",
                message=user_message_card.to_dict(),
                type="start",
                message_id=generate_uuid(),
                id=generate_uuid(),
            )
            await websocket.send_json(resp.model_dump())
            # Use the directly-imported generate_uuid consistently (the
            # original mixed it with utils.generate_uuid — same function).
            message_id: str = generate_uuid()
            custom_handler = CustomAsyncCallbackHandler(
                websocket, message_id=message_id
            )
            await agent_executor.arun(
                input=user_message,
                chat_history=chat_history,
                callbacks=[custom_handler],
            )
        except WebSocketDisconnect:
            logger.info("websocket disconnect")
            break
|