"""OpenAI chat wrapper.""" |
|
|
|
from __future__ import annotations |
|
|
|
from typing import ( |
|
Any, |
|
AsyncIterator, |
|
Iterator, |
|
List, |
|
Optional, |
|
Union, |
|
) |
|
|
|
from langchain_community.chat_models import ChatOpenAI, AzureChatOpenAI |
|
from langchain_community.chat_models.openai import acompletion_with_retry, _convert_delta_to_message_chunk |
|
from langchain_core.callbacks import ( |
|
AsyncCallbackManagerForLLMRun, |
|
CallbackManagerForLLMRun, |
|
) |
|
from langchain_core.language_models.chat_models import ( |
|
agenerate_from_stream, |
|
generate_from_stream, |
|
) |
|
from langchain_core.messages import ( |
|
AIMessageChunk, |
|
BaseMessage, |
|
) |
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult |
|
from langchain_core.pydantic_v1 import BaseModel |
|
|
|
from langchain_community.adapters.openai import ( |
|
convert_dict_to_message, |
|
) |
|
|
|
|
|
class H2OBaseChatOpenAI: |
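    """Shared H2O overrides for ChatOpenAI-style chat models.

    When ``tools`` are passed through ``model_kwargs``, the tool-call arguments
    returned by the API are surfaced as the message text, both while streaming
    and for one-shot generation.
    """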
|
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
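        """Stream chat completion chunks, passing each one through ``mod_cg_chunk``."""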
|
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info
            )
            cg_chunk = self.mod_cg_chunk(cg_chunk)
            if run_manager:
                run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk
|
def mod_cg_chunk(self, cg_chunk: ChatGenerationChunk) -> ChatGenerationChunk: |
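        """When tools are requested, surface the first tool call's arguments as the
        chunk text; chunks that carry no tool call get empty text."""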
|
        if 'tools' in self.model_kwargs and self.model_kwargs['tools']:
            if 'tool_calls' in cg_chunk.message.additional_kwargs:
                cg_chunk.message.content = cg_chunk.text = (
                    cg_chunk.message.additional_kwargs['tool_calls'][0]['function']['arguments']
                )
            else:
                cg_chunk.text = ''
        return cg_chunk
|
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
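        """Generate a chat result, delegating to ``_stream`` when streaming is enabled."""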
|
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)
|
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: |
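        """Build a ``ChatResult`` from an OpenAI response, folding tool-call
        arguments into the message content when tools are in use."""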
|
        generations = []
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = convert_dict_to_message(res["message"])

            if 'tools' in self.model_kwargs and self.model_kwargs['tools']:
                if 'tool_calls' in message.additional_kwargs:
                    message.content = ''.join(
                        x['function']['arguments']
                        for x in message.additional_kwargs['tool_calls']
                    )

            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)
|
    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
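        """Async counterpart of ``_stream``; each chunk passes through ``mod_cg_chunk``."""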
|
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info
            )
            cg_chunk = self.mod_cg_chunk(cg_chunk)
            if run_manager:
                await run_manager.on_llm_new_token(token=cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk
|
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
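        """Async counterpart of ``_generate``, delegating to ``_astream`` when streaming."""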
|
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)
|
|
class H2OBaseAzureChatOpenAI(H2OBaseChatOpenAI, AzureChatOpenAI):
    pass
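# Minimal usage sketch (a non-authoritative example; the endpoint, deployment
# name, and API version below are placeholder assumptions, not defaults):
#
#   llm = H2OBaseAzureChatOpenAI(
#       azure_endpoint="https://<resource>.openai.azure.com/",
#       azure_deployment="my-deployment",
#       api_version="2024-02-01",
#       streaming=True,
#   )
#   for chunk in llm.stream("Hello"):
#       print(chunk.content, end="", flush=True)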
|
|