Spaces:
Sleeping
Sleeping
import os | |
import openai | |
import requests | |
from typing import ( | |
Any, | |
Dict, | |
List, | |
Optional, | |
) | |
from langchain.schema import ( | |
AIMessage, | |
BaseMessage, | |
ChatMessage, | |
HumanMessage, | |
SystemMessage, | |
) | |
from langchain.callbacks.manager import CallbackManagerForLLMRun | |
from langchain.chat_models.base import SimpleChatModel | |
import os | |
import sys | |
sys.path.append(os.getcwd()) | |
from cllm.services.nlp.llms.memory import MessageMemory | |
from cllm.utils import timeout | |
class ChatOpenAI: | |
def __init__( | |
self, | |
model_name: str = "gpt-3.5-turbo", | |
temperature: float = 0.7, | |
model_kwargs: Dict[str, Any] = dict(), | |
openai_api_key: Optional[str] = None, | |
openai_base_url: Optional[str] = None, | |
) -> None: | |
self.model_name = model_name | |
self.temperature = temperature | |
self.model_kwargs = model_kwargs | |
self.api_key = os.environ.get("OPENAI_API_KEY", openai_api_key) | |
self.base_url = os.environ.get("OPENAI_BASE_URL", openai_base_url) | |
def __call__(self, messages: List[BaseMessage], **kwargs): | |
stream = kwargs.get("stream", False) | |
context = MessageMemory(messages=messages) | |
context.cut_memory(self.model_name) | |
response = self.send_message(messages=context.to_dict(), stream=stream) | |
return response | |
def get_response(self, response): | |
return response.choices[0].message.content | |
def send_message(self, messages, stream=False): | |
cnt = 10 | |
while cnt > 0: | |
try: | |
result = self.get_response( | |
self._send_message( | |
model=self.model_name, | |
messages=messages, | |
temperature=self.temperature, | |
stream=stream, | |
timeout=5, | |
) | |
) | |
break | |
except Exception as e: | |
cnt -= 1 | |
print(e) | |
result = e | |
return result | |
# @timeout(5) | |
def _send_message(self, *args, **kwargs): | |
# return self.client.chat.completions.create(*args, **kwargs) | |
# return openai.Completion.create(*args, **kwargs) | |
return openai.chat.completions.create(*args, **kwargs) | |
class ChatLLAMA2(SimpleChatModel): | |
"""Wrapper around LLAMA2 | |
To use, you should launch you local model as web services. | |
""" | |
client: Any = None #: :meta private: | |
endpoint: str = "http://localhost:10051" | |
HUMAN_PROMPT = "user" | |
AI_PROMPT = "assistant" | |
def _llm_type(self) -> str: | |
"""Return type of chat model.""" | |
return "local-chat" | |
def _call( | |
self, | |
messages: List[BaseMessage], | |
stop: Optional[List[str]] = None, | |
run_manager: Optional[CallbackManagerForLLMRun] = None, | |
) -> str: | |
data = self._convert_messages_to_prompt(messages) | |
response = requests.post(self.endpoint, json=data) | |
return response.content.decode() | |
def _convert_one_message_to_text(self, message: BaseMessage) -> str: | |
if isinstance(message, ChatMessage): | |
message_text = { | |
"role": message.role.capitalize(), | |
"content": message.content, | |
} | |
elif isinstance(message, HumanMessage): | |
message_text = {"role": self.HUMAN_PROMPT, "content": message.content} | |
elif isinstance(message, AIMessage): | |
message_text = {"role": self.AI_PROMPT, "content": message.content} | |
elif isinstance(message, SystemMessage): | |
message_text = {"role": "system", "content": message.content} | |
else: | |
raise ValueError(f"Got unknown type {message}") | |
return message_text | |
def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str: | |
"""Format a list of strings into a single string with necessary newlines. | |
Args: | |
messages (List[BaseMessage]): List of BaseMessage to combine. | |
Returns: | |
str: Combined string with necessary newlines. | |
""" | |
return [self._convert_one_message_to_text(message) for message in messages] | |
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str: | |
"""Format a list of messages into a full prompt for the Anthropic model | |
Args: | |
messages (List[BaseMessage]): List of BaseMessage to combine. | |
Returns: | |
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags. | |
""" | |
return self._convert_messages_to_text(messages) | |
class ChatLLAMA2(SimpleChatModel): | |
"""Wrapper around LLAMA2 | |
To use, you should launch you local model as web services. | |
""" | |
client: Any = None #: :meta private: | |
endpoint: str = "http://localhost:10051" | |
HUMAN_PROMPT = "user" | |
AI_PROMPT = "assistant" | |
def _llm_type(self) -> str: | |
"""Return type of chat model.""" | |
return "local-chat" | |
def _call( | |
self, | |
messages: List[BaseMessage], | |
stop: Optional[List[str]] = None, | |
run_manager: Optional[CallbackManagerForLLMRun] = None, | |
) -> str: | |
data = self._convert_messages_to_prompt(messages) | |
response = requests.post(self.endpoint, json=data) | |
return response.content.decode() | |
def _convert_one_message_to_text(self, message: BaseMessage) -> str: | |
if isinstance(message, ChatMessage): | |
message_text = { | |
"role": message.role.capitalize(), | |
"content": message.content, | |
} | |
elif isinstance(message, HumanMessage): | |
message_text = {"role": self.HUMAN_PROMPT, "content": message.content} | |
elif isinstance(message, AIMessage): | |
message_text = {"role": self.AI_PROMPT, "content": message.content} | |
elif isinstance(message, SystemMessage): | |
message_text = {"role": "system", "content": message.content} | |
else: | |
raise ValueError(f"Got unknown type {message}") | |
return message_text | |
def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str: | |
"""Format a list of strings into a single string with necessary newlines. | |
Args: | |
messages (List[BaseMessage]): List of BaseMessage to combine. | |
Returns: | |
str: Combined string with necessary newlines. | |
""" | |
return [self._convert_one_message_to_text(message) for message in messages] | |
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str: | |
"""Format a list of messages into a full prompt for the Anthropic model | |
Args: | |
messages (List[BaseMessage]): List of BaseMessage to combine. | |
Returns: | |
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags. | |
""" | |
return self._convert_messages_to_text(messages) | |
if __name__ == "__main__": | |
chat = ChatOpenAI() | |
msg = [ | |
SystemMessage(content="You are a helpful assistant."), | |
HumanMessage(content="Hello!"), | |
] | |
response = chat(msg) | |
print(response) | |