aiben / src /langchain_openai_local.py
abugaber's picture
Upload folder using huggingface_hub
3943768 verified
raw
history blame
7.32 kB
"""OpenAI chat wrapper."""
from __future__ import annotations
from typing import (
Any,
AsyncIterator,
Iterator,
List,
Optional,
Union,
)
from langchain_community.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain_community.chat_models.openai import acompletion_with_retry, _convert_delta_to_message_chunk
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_community.adapters.openai import (
convert_dict_to_message,
)
class H2OBaseChatOpenAI:
    """Mixin that surfaces OpenAI tool-call arguments as the generated text.

    The class defines no state of its own.  Its methods rely on attributes
    and methods supplied by the chat model it is mixed into:
    ``model_kwargs``, ``streaming``, ``model_name``,
    ``_create_message_dicts`` and ``completion_with_retry``.

    When ``model_kwargs['tools']`` is set to a truthy value, the text of
    every generation is taken from the ``function.arguments`` of the tool
    calls rather than from the regular message content.
    """

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield generation chunks from a streaming chat completion.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences, forwarded to
                ``_create_message_dicts``.
            run_manager: Optional callback manager; notified of every new
                token via ``on_llm_new_token``.
            **kwargs: Extra request parameters; they override the defaults
                returned by ``_create_message_dicts``.

        Yields:
            One ``ChatGenerationChunk`` per streamed chunk that carries at
            least one choice, after post-processing by ``mod_cg_chunk``.
        """
        message_dicts, params = self._create_message_dicts(messages, stop)
        # "stream" is placed last so that callers cannot switch it off.
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            # Chunks without any choice carry nothing to emit.
            if not chunk["choices"]:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            # Later deltas are converted with the class of the previous chunk.
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info
            )
            cg_chunk = self.mod_cg_chunk(cg_chunk)
            if run_manager:
                run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk

    def mod_cg_chunk(self, cg_chunk: ChatGenerationChunk) -> ChatGenerationChunk:
        """Rewrite a streamed chunk so tool-call arguments become its text.

        Nothing is changed unless ``model_kwargs['tools']`` is truthy.  With
        tools enabled:

        * if the chunk carries tool calls, both the message content and the
          chunk text are set to the ``function.arguments`` of the *first*
          tool call (an absent or ``None`` value is treated as ``''``);
        * otherwise (no ``tool_calls`` entry, or an empty one) the chunk
          text is blanked and the message content is left untouched.

        Args:
            cg_chunk: The chunk to adjust; it is modified in place.

        Returns:
            The same ``cg_chunk`` object.
        """
        if not self.model_kwargs.get('tools'):
            return cg_chunk

        tool_calls = cg_chunk.message.additional_kwargs.get('tool_calls')
        if tool_calls:
            # Guard against a missing/None function payload or arguments so
            # that the content always stays a string.
            function = tool_calls[0].get('function') or {}
            arguments = function.get('arguments') or ''
            cg_chunk.message.content = cg_chunk.text = arguments
        else:
            cg_chunk.text = ''
        return cg_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run a chat completion and return the full result.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Optional callback manager for the run.
            stream: Explicit streaming choice; when ``None`` the instance's
                ``streaming`` attribute decides.
            **kwargs: Extra request parameters; they take precedence over
                the defaults and over ``stream``.

        Returns:
            The ``ChatResult`` assembled either from the streamed chunks or
            from the single completion response.
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            # Only forward "stream" when the caller stated it explicitly.
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        """Convert a completion response into a ``ChatResult``.

        With tools enabled (truthy ``model_kwargs['tools']``), the content
        of every message that carries ``tool_calls`` is replaced by the
        concatenation of all tool-call ``function.arguments`` strings
        (``None`` arguments count as ``''``).

        Args:
            response: The completion response, either a dict or an object
                exposing ``.dict()``.

        Returns:
            A ``ChatResult`` with one generation per choice and an
            ``llm_output`` holding ``token_usage``, ``model_name`` and
            ``system_fingerprint``.
        """
        if not isinstance(response, dict):
            response = response.dict()

        # The tools setting does not change between choices; check it once.
        tools_enabled = bool(self.model_kwargs.get('tools'))

        generations = []
        for res in response["choices"]:
            message = convert_dict_to_message(res["message"])
            if tools_enabled and 'tool_calls' in message.additional_kwargs:
                message.content = ''.join(
                    x['function']['arguments'] or ''
                    for x in message.additional_kwargs['tool_calls']
                )
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            generations.append(
                ChatGeneration(
                    message=message,
                    generation_info=generation_info,
                )
            )

        llm_output = {
            "token_usage": response.get("usage", {}),
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously yield generation chunks from a streaming completion.

        Async counterpart of ``_stream``: the request goes through
        ``acompletion_with_retry`` and the callback manager is awaited.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Optional async callback manager; notified of every
                new token via ``on_llm_new_token``.
            **kwargs: Extra request parameters; they override the defaults
                returned by ``_create_message_dicts``.

        Yields:
            One ``ChatGenerationChunk`` per streamed chunk that carries at
            least one choice, after post-processing by ``mod_cg_chunk``.
        """
        message_dicts, params = self._create_message_dicts(messages, stop)
        # "stream" is placed last so that callers cannot switch it off.
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            # Chunks without any choice carry nothing to emit.
            if not chunk["choices"]:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            # Later deltas are converted with the class of the previous chunk.
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info
            )
            cg_chunk = self.mod_cg_chunk(cg_chunk)
            if run_manager:
                await run_manager.on_llm_new_token(token=cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously run a chat completion and return the full result.

        Async counterpart of ``_generate``.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Optional async callback manager for the run.
            stream: Explicit streaming choice; when ``None`` the instance's
                ``streaming`` attribute decides.
            **kwargs: Extra request parameters; they take precedence over
                the defaults and over ``stream``.

        Returns:
            The ``ChatResult`` assembled either from the streamed chunks or
            from the single completion response.
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            # Only forward "stream" when the caller stated it explicitly.
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)
class H2OBaseAzureChatOpenAI(H2OBaseChatOpenAI, AzureChatOpenAI):
    """``AzureChatOpenAI`` combined with the ``H2OBaseChatOpenAI`` overrides.

    ``H2OBaseChatOpenAI`` is listed first among the bases, so its methods
    take precedence over those of ``AzureChatOpenAI``.  The class adds no
    members of its own.
    """