doc-search / util.py
briankchan's picture
Initial commit
2afc3f4
from typing import Any, Dict, List, Optional, Union
from types import GeneratorType
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.embeddings.openai import embed_with_retry, OpenAIEmbeddings
from pydantic import Extra, Field, root_validator
import numpy as np
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses to a queue.

    Every token reported through ``on_llm_new_token`` is pushed onto the
    queue object supplied at construction time.
    """

    def __init__(self, q):
        # q: any object exposing ``put(item)``; it receives each new token.
        # NOTE(review): ``put`` is called without ``await`` below, so q is
        # presumably a thread-safe ``queue.Queue`` rather than an
        # ``asyncio.Queue`` -- confirm against the callers.
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Push the newly generated ``token`` onto the queue; ``kwargs`` are ignored."""
        self.q.put(token)
class SyncStreamingLLMCallbackHandler(BaseCallbackHandler):
    """Synchronous callback handler that streams LLM tokens to a queue.

    Only ``on_llm_new_token`` has an effect: it pushes the token onto the
    queue given to the constructor. Every other hook is an explicit no-op.
    """

    def __init__(self, q):
        # q: any object exposing ``put(item)``; it receives each new token.
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Ignore the start of an LLM run."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Push the newly generated token onto the queue."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Ignore the end of an LLM run."""

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Ignore LLM errors."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Ignore the start of a chain run."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Ignore the end of a chain run."""

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Ignore chain errors."""

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Ignore the start of a tool run."""

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Ignore the end of a tool run."""

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Ignore tool errors."""

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Ignore agent actions; implicitly returns ``None``."""

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Ignore the end of an agent run."""
def concatenate_generators(*args):
    """Chain strings and string generators into one stream of growing text.

    Each argument is either a plain string or a generator of strings. A
    generator is expected to yield progressively more complete versions of
    its own segment: every value it yields is emitted prefixed with the text
    of all earlier segments, and only its last value is kept as that
    segment's final text. A plain string is emitted once (with the same
    prefix) and kept as is.

    Args:
        *args: Strings and/or generators of strings, in output order.

    Yields:
        The text accumulated from earlier arguments followed by the current
        value of the argument being processed.
    """
    final_outputs = ""
    for g in args:
        if isinstance(g, GeneratorType):
            # Default to "" so a generator that yields nothing contributes
            # nothing, instead of raising UnboundLocalError or re-appending
            # the last value left over from a previous generator.
            result = ""
            for result in g:
                yield final_outputs + result
        else:
            result = g
            yield final_outputs + result
        final_outputs += result
class CustomOpenAIEmbeddings(OpenAIEmbeddings):
    """A version of OpenAIEmbeddings that allows extra args to be passed
    to OpenAI functions.

    Constructor arguments that are not declared fields are collected into
    ``model_kwargs`` and forwarded to every ``embed_with_retry`` call.
    Based on langchain's ChatOpenAI.
    """

    # Extra keyword arguments forwarded to every embedding request.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Args:
            values: Raw constructor arguments, before field validation.

        Returns:
            ``values`` with every argument that is not a declared field moved
            into ``values["model_kwargs"]``.

        Raises:
            ValueError: If an argument is supplied both directly and inside
                ``model_kwargs``, or if ``model_kwargs`` contains a declared
                field name or ``"model"``.
        """
        all_required_field_names = {field.alias for field in cls.__fields__.values()}
        # Work on a copy so the dict object the caller passed as
        # ``model_kwargs`` is never mutated.
        extra = dict(values.get("model_kwargs") or {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                # Not a declared field: treat it as an extra request argument.
                extra[field_name] = values.pop(field_name)

        disallowed_model_kwargs = all_required_field_names | {"model"}
        invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    # use extra args in calls

    # please refer to
    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
    def _get_len_safe_embeddings(
        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
    ) -> List[List[float]]:
        """Embed ``texts``, splitting each one into token chunks first.

        Each text is tokenized and cut into chunks of at most
        ``self.embedding_ctx_length`` tokens. The chunks are embedded in
        batches, and the chunk embeddings of each text are combined into a
        token-count-weighted average that is scaled to unit length.

        Args:
            texts: Texts to embed.
            engine: Accepted for interface compatibility; requests are sent
                with ``engine=self.deployment``.
            chunk_size: Number of chunks per request; falls back to
                ``self.chunk_size`` when falsy.

        Returns:
            One embedding per input text, in input order.

        Raises:
            ValueError: If the ``tiktoken`` package cannot be imported.
        """
        # Keep the try body to the import alone, so an ImportError raised by
        # anything else is not misreported as a missing tiktoken package.
        try:
            import tiktoken
        except ImportError as err:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to for OpenAIEmbeddings. "
                "Please install it with `pip install tiktoken`."
            ) from err

        embeddings: List[List[float]] = [[] for _ in range(len(texts))]

        tokens = []  # token chunks, across all texts
        indices = []  # indices[k] is the index of the text that tokens[k] came from
        encoding = tiktoken.model.encoding_for_model(self.model)
        for i, text in enumerate(texts):
            if self.model.endswith("001"):
                # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
                # replace newlines, which can negatively affect performance.
                text = text.replace("\n", " ")
            token = encoding.encode(
                text,
                allowed_special=self.allowed_special,
                disallowed_special=self.disallowed_special,
            )
            for j in range(0, len(token), self.embedding_ctx_length):
                tokens.append(token[j : j + self.embedding_ctx_length])
                indices.append(i)

        batched_embeddings = []
        _chunk_size = chunk_size or self.chunk_size
        for i in range(0, len(tokens), _chunk_size):
            response = embed_with_retry(
                self,
                input=tokens[i : i + _chunk_size],
                engine=self.deployment,
                request_timeout=self.request_timeout,
                **self.model_kwargs,
            )
            batched_embeddings += [r["embedding"] for r in response["data"]]

        # Regroup the chunk embeddings (and their token counts) by source text.
        results: List[List[List[float]]] = [[] for _ in range(len(texts))]
        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
        for k, text_index in enumerate(indices):
            results[text_index].append(batched_embeddings[k])
            num_tokens_in_batch[text_index].append(len(tokens[k]))

        for i, _result in enumerate(results):
            if not _result:
                # The text produced no tokens: embed the empty string instead.
                average = embed_with_retry(
                    self,
                    input="",
                    engine=self.deployment,
                    request_timeout=self.request_timeout,
                    **self.model_kwargs,
                )["data"][0]["embedding"]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()

        return embeddings

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint.

        Args:
            text: Text to embed.
            engine: Engine name passed to the request.

        Returns:
            The embedding of ``text``.
        """
        # handle large input text
        if len(text) > self.embedding_ctx_length:
            return self._get_len_safe_embeddings([text], engine=engine)[0]

        if self.model.endswith("001"):
            # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
            # replace newlines, which can negatively affect performance.
            text = text.replace("\n", " ")
        return embed_with_retry(
            self,
            input=[text],
            engine=engine,
            request_timeout=self.request_timeout,
            **self.model_kwargs,
        )["data"][0]["embedding"]