# ControlLLM: cllm/services/nlp/llms/chat_models.py
import os
import sys

import openai
import requests
from typing import (
    Any,
    Dict,
    List,
    Optional,
)
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel

# Make the repository root importable when this module is run as a script.
sys.path.append(os.getcwd())

from cllm.services.nlp.llms.memory import MessageMemory
from cllm.utils import timeout  # used by the optional @timeout decorator below


class ChatOpenAI:
    """Thin wrapper around the OpenAI chat completion API (openai>=1.0)."""

    def __init__(
        self,
        model_name: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        model_kwargs: Optional[Dict[str, Any]] = None,
        openai_api_key: Optional[str] = None,
        openai_base_url: Optional[str] = None,
    ) -> None:
        self.model_name = model_name
        self.temperature = temperature
        self.model_kwargs = model_kwargs or {}
        # Environment variables take precedence over explicit arguments.
        self.api_key = os.environ.get("OPENAI_API_KEY", openai_api_key)
        self.base_url = os.environ.get("OPENAI_BASE_URL", openai_base_url)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
def __call__(self, messages: List[BaseMessage], **kwargs):
stream = kwargs.get("stream", False)
context = MessageMemory(messages=messages)
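        # Trim older messages so the request fits the model's context window
        # (inferred from MessageMemory.cut_memory's name; assumption).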
context.cut_memory(self.model_name)
response = self.send_message(messages=context.to_dict(), stream=stream)
return response
def get_response(self, response):
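        # Non-streaming responses only: pull the assistant text from the first choice.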
return response.choices[0].message.content

    def send_message(self, messages, stream=False):
        # Retry a handful of times; transient API errors are printed and retried.
        last_error: Optional[Exception] = None
        for _ in range(10):
            try:
                response = self._send_message(
                    model=self.model_name,
                    messages=messages,
                    temperature=self.temperature,
                    stream=stream,
                    timeout=5,
                    **self.model_kwargs,
                )
                # Streaming responses are returned as-is; otherwise extract the text.
                return response if stream else self.get_response(response)
            except Exception as e:
                last_error = e
                print(e)
        return str(last_error)

    # @timeout(5)
    def _send_message(self, *args, **kwargs):
        # openai>=1.0 client interface; the per-request `timeout` kwarg is honoured here.
        return self.client.chat.completions.create(*args, **kwargs)


class ChatLLAMA2(SimpleChatModel):
    """Wrapper around a locally served LLAMA2 model.

    To use it, launch your local model as a web service and point
    ``endpoint`` at it.
    """

client: Any = None #: :meta private:
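    # Endpoint of the locally launched LLAMA2 web service (default port is an assumption).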
endpoint: str = "http://localhost:10051"
HUMAN_PROMPT = "user"
AI_PROMPT = "assistant"
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "local-chat"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
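        # The local service is assumed to accept a JSON list of
        # {"role": ..., "content": ...} dicts and to return the reply as plain text.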
data = self._convert_messages_to_prompt(messages)
response = requests.post(self.endpoint, json=data)
return response.content.decode()

    def _convert_one_message_to_text(self, message: BaseMessage) -> Dict[str, str]:
if isinstance(message, ChatMessage):
message_text = {
"role": message.role.capitalize(),
"content": message.content,
}
elif isinstance(message, HumanMessage):
message_text = {"role": self.HUMAN_PROMPT, "content": message.content}
elif isinstance(message, AIMessage):
message_text = {"role": self.AI_PROMPT, "content": message.content}
elif isinstance(message, SystemMessage):
message_text = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_text

    def _convert_messages_to_text(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, str]]:
        """Convert a list of messages into a list of role/content dicts.

        Args:
            messages (List[BaseMessage]): Messages to convert.

        Returns:
            List[Dict[str, str]]: One {"role", "content"} dict per message.
        """
return [self._convert_one_message_to_text(message) for message in messages]

    def _convert_messages_to_prompt(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, str]]:
        """Format a list of messages into the payload expected by the local
        LLAMA2 web service.

        Args:
            messages (List[BaseMessage]): Messages to combine.

        Returns:
            List[Dict[str, str]]: Role/content dicts ready to be sent as JSON.
        """
        return self._convert_messages_to_text(messages)


if __name__ == "__main__":
chat = ChatOpenAI()
msg = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello!"),
]
response = chat(msg)
print(response)
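
    # Hedged sketch: exercising ChatLLAMA2 assumes a LLAMA2 web service is
    # already running at the endpoint below; adjust host/port to your setup.
    # llama_chat = ChatLLAMA2(endpoint="http://localhost:10051")
    # print(llama_chat(msg))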