Suraj Yadav committed
Commit 0a518ff · 1 Parent(s): 67dc3b9
Implemented tool calling chatbot node [skip ci]
src/basicchatbot/llms/__init__.py
CHANGED
@@ -1,12 +1,12 @@
 import streamlit as st
 from typing import Dict, Optional, Type, Union
-
+from langchain_core.language_models.chat_models import BaseChatModel
 from src.basicchatbot.llms.base_llm import BaseLLMProvider
 from src.basicchatbot.llms.groq_llm import GroqLLM
 from src.basicchatbot.llms.openai_llm import OpenAILLM
 
 
-def get_llm(llm_name: str, user_input: Dict[str, str]) -> Optional[
+def get_llm(llm_name: str, user_input: Dict[str, str]) -> Optional[BaseChatModel]:
     """
     Function to get the appropriate LLM model instance.
 
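With this change, get_llm is annotated to return any LangChain chat model (BaseChatModel) rather than a provider-specific type. A minimal sketch of a call site under that contract; the provider name and the keys inside user_input are assumptions for illustration and are not shown in this commit:

    from src.basicchatbot.llms import get_llm

    # Hypothetical provider name and config keys.
    llm = get_llm("Groq", {"GROQ_API_KEY": "gsk_...", "selected_model": "llama-3.1-8b-instant"})
    if llm is not None:
        # Every BaseChatModel exposes .invoke(); a plain string is accepted as input.
        reply = llm.invoke("Hello!")
        print(reply.content)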
src/basicchatbot/llms/groq_llm.py
CHANGED
@@ -2,12 +2,13 @@ import os
 import streamlit as st
 from typing import Dict, Optional
 from langchain_groq import ChatGroq
+from langchain_core.language_models.chat_models import BaseChatModel
 from src.basicchatbot.llms.base_llm import BaseLLMProvider
 
 
 class GroqLLM(BaseLLMProvider):
 
-    def get_llm_model(self) -> Optional[
+    def get_llm_model(self) -> Optional[BaseChatModel]:
         try:
             # Clear previous error messages
             self.error_messages = []
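The hunk shows only the new signature; the body that constructs the ChatGroq instance sits outside the diff context. A minimal sketch of the pattern the annotation implies, assuming hypothetical self.user_controls fields and key names that are not taken from this commit:

    from typing import Optional

    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_groq import ChatGroq


    class GroqLLMSketch:
        """Illustrative stand-in for GroqLLM; field and key names are assumptions."""

        def __init__(self, user_controls: dict) -> None:
            self.user_controls = user_controls
            self.error_messages: list = []

        def get_llm_model(self) -> Optional[BaseChatModel]:
            try:
                # Clear previous error messages, as in the real method.
                self.error_messages = []
                api_key = self.user_controls.get("GROQ_API_KEY")
                if not api_key:
                    self.error_messages.append("Groq API key is missing.")
                    return None
                # ChatGroq subclasses BaseChatModel, so it satisfies the new annotation.
                return ChatGroq(api_key=api_key, model=self.user_controls.get("selected_model", "llama-3.1-8b-instant"))
            except Exception as exc:
                self.error_messages.append(str(exc))
                return None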
src/basicchatbot/llms/openai_llm.py
CHANGED
@@ -1,12 +1,13 @@
 import os
 import streamlit as st
 from typing import Dict, Optional
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_openai import ChatOpenAI
 from src.basicchatbot.llms.base_llm import BaseLLMProvider
 
 class OpenAILLM(BaseLLMProvider):
 
-    def get_llm_model(self) -> Optional[
+    def get_llm_model(self) -> Optional[BaseChatModel]:
         try:
             # Clear previous error messages
             self.error_messages = []
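openai_llm.py receives the same treatment as groq_llm.py: the new import plus the widened return annotation. The practical effect is that both providers can be handled through one type. A small illustration; the model names and placeholder API keys below are assumptions:

    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_groq import ChatGroq
    from langchain_openai import ChatOpenAI

    # Constructor arguments are placeholders for illustration only.
    candidates = [
        ChatOpenAI(model="gpt-4o-mini", api_key="sk-..."),
        ChatGroq(model="llama-3.1-8b-instant", api_key="gsk_..."),
    ]
    assert all(isinstance(m, BaseChatModel) for m in candidates)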
src/basicchatbot/nodes/__init__.py
CHANGED
@@ -0,0 +1,2 @@
+from src.basicchatbot.nodes.basic_chatbot_node import BasicChatBotNode
+from src.basicchatbot.nodes.chatbot_with_tools import ChatbotWithToolsNode
src/basicchatbot/nodes/chatbot_with_tools.py
ADDED
@@ -0,0 +1,41 @@
+from typing import Any, List
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import AIMessage
+from src.basicchatbot.state.state import BasicChatBotState
+
+
+class ChatbotWithToolsNode:
+    """Handles chatbot interactions using an LLM and associated tools."""
+
+    def __init__(self, model: BaseChatModel, tools: List[Any]) -> None:
+        """
+        Initialize the chatbot node with a model and tools.
+
+        Args:
+            model (BaseChatModel): The language model used for processing messages.
+            tools (List[Any]): A list of tools that can be used with the chatbot.
+        """
+        self.llm = model
+        self.tools = tools
+
+    def node(self, state: BasicChatBotState) -> dict:
+        """
+        Processes the chatbot state and generates a response.
+
+        Args:
+            state (BasicChatBotState): The current chatbot state containing messages.
+
+        Returns:
+            dict: A dictionary containing the chatbot's response messages.
+        """
+        try:
+            messages = state.get("messages", [])
+            if not messages:
+                return {"messages": [AIMessage(content="ERROR: `messages` key is missing in the state. Contact developer to fix.")]}
+
+            response = self.llm.bind_tools(self.tools).invoke(input=messages)
+
+            return {"messages": [response]}
+
+        except Exception as e:
+            return {"messages": [AIMessage(content=f"Error processing request: {str(e)}")]}
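This commit does not include the graph wiring for the new node. A minimal sketch of how ChatbotWithToolsNode might be attached in a LangGraph StateGraph, assuming BasicChatBotState carries an additive "messages" channel; the TavilySearchResults tool, the Groq model, and the placeholder API keys are assumptions:

    from langchain_community.tools.tavily_search import TavilySearchResults
    from langchain_groq import ChatGroq
    from langgraph.graph import StateGraph, START
    from langgraph.prebuilt import ToolNode, tools_condition

    from src.basicchatbot.nodes.chatbot_with_tools import ChatbotWithToolsNode
    from src.basicchatbot.state.state import BasicChatBotState

    # Placeholder model and tool; real provider/tool selection is outside this commit.
    llm = ChatGroq(model="llama-3.1-8b-instant", api_key="gsk_...")
    tools = [TavilySearchResults(max_results=2)]

    builder = StateGraph(BasicChatBotState)
    builder.add_node("chatbot", ChatbotWithToolsNode(model=llm, tools=tools).node)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "chatbot")
    builder.add_conditional_edges("chatbot", tools_condition)  # route to "tools" or END
    builder.add_edge("tools", "chatbot")
    graph = builder.compile()

    result = graph.invoke({"messages": [("user", "Search for today's top AI news.")]})
    print(result["messages"][-1].content)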