# HF libraries
from langchain_huggingface import HuggingFaceEndpoint
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
# Import things that are needed generically
from langchain.tools.render import render_text_description
from app.schemas.message_schema import (
    IChatResponse,
)
from app.utils.utils import generate_uuid
import os
from dotenv import load_dotenv
from app.utils.adaptive_cards.cards import create_adaptive_card
from app.structured_tools.structured_tools import (
    arxiv_search,
    get_arxiv_paper,
    google_search,
    wikipedia_search,
    knowledgeBase_search,
    memory_search,
)
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from langchain.prompts import PromptTemplate
from app.templates.react_json_with_memory import template_system
from app.utils import logger
from app.utils import utils
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from app.utils.callback import (
    CustomAsyncCallbackHandler,
    CustomFinalStreamingStdOutCallbackHandler,
)
from langchain.memory import ConversationBufferMemory
from app.core.config import settings

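# Optional SQLite-backed LLM cache; enable by uncommenting set_llm_cache below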
local_cache = settings.LOCAL_CACHE
# set_llm_cache(SQLiteCache(database_path=local_cache))
logger = logger.get_console_logger("hf_mixtral_agent")

config = load_dotenv(".env")
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
GOOGLE_CSE_ID = os.getenv('GOOGLE_CSE_ID')
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
# LangSmith tracing is picked up from environment variables, not module constants
os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")
os.environ.setdefault("LANGCHAIN_ENDPOINT", "https://api.smith.langchain.com")
LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')
LANGCHAIN_PROJECT = os.getenv('LANGCHAIN_PROJECT')
# GOOGLE_CSE_ID=settings.GOOGLE_CSE_ID
# GOOGLE_API_KEY=settings.GOOGLE_API_KEY
# HUGGINGFACEHUB_API_TOKEN=settings.HUGGINGFACEHUB_API_TOKEN
# print(HUGGINGFACEHUB_API_TOKEN)

router = APIRouter()

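# Conversation memory shared across connections (not currently passed to the AgentExecutor below)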
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

@router.websocket("/agent")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
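    # Fail fast if the Hugging Face API token is missing or malformed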
    if not settings.HUGGINGFACEHUB_API_TOKEN.startswith("hf_"):
        await websocket.send_json({"error": "HUGGINGFACEHUB_API_TOKEN is not set"})
        await websocket.close()
        return

    while True:
        try:
            data = await websocket.receive_json()
            user_message = data["message"]
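            # Echo the user's message back to the client as an Adaptive Card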
            user_message_card = create_adaptive_card(user_message)
            chat_history = []  # data["history"]

            resp = IChatResponse(
                sender="you",
                message=user_message_card.to_dict(),
                type="start",
                message_id=generate_uuid(),
                id=generate_uuid(),
            )

            await websocket.send_json(resp.model_dump())
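            # A fresh message id identifies the agent's streamed response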
            message_id: str = utils.generate_uuid()
            custom_handler = CustomAsyncCallbackHandler(
                websocket, message_id=message_id
            )

            # Load the model from the Hugging Face Hub
            llm = HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                temperature=0.1,
                max_new_tokens=1024,
                repetition_penalty=1.2,
                return_full_text=False,
            )


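            # Tools exposed to the agent (get_arxiv_paper is imported but left disabled)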
            tools = [
                memory_search,
                knowledgeBase_search,
                arxiv_search,
                wikipedia_search,
                google_search,
                # get_arxiv_paper,
            ]

            prompt = PromptTemplate.from_template(
                template=template_system
            )
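            # Inject the rendered tool descriptions and tool names into the prompt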
            prompt = prompt.partial(
                tools=render_text_description(tools),
                tool_names=", ".join([t.name for t in tools]),
            )


            # Define the agent: bind a stop sequence so the model halts before
            # emitting its own "Observation" instead of waiting for a tool result
            chat_model_with_stop = llm.bind(stop=["\nObservation"])
            agent = (
                {
                    "input": lambda x: x["input"],
                    "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
                    "chat_history": lambda x: x["chat_history"],
                }
                | prompt
                | chat_model_with_stop
                | ReActJsonSingleInputOutputParser()
            )

            # instantiate AgentExecutor
            agent_executor = AgentExecutor(
                agent=agent,
                tools=tools,
                verbose=True,
                max_iterations=10,       # cap number of iterations
                #max_execution_time=60,  # timeout at 60 sec
                return_intermediate_steps=True,
                handle_parsing_errors=True,
                #memory=memory
            )
            
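            # Run the agent; the custom callback handler streams output back over the websocket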
            await agent_executor.arun(
                input=user_message,
                chat_history=chat_history,
                callbacks=[custom_handler],
            )
        except WebSocketDisconnect:
            logger.info("websocket disconnect")
            break