from typing import Any, Dict, List, Optional, Union
from types import GeneratorType

from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.embeddings.openai import embed_with_retry, OpenAIEmbeddings
from pydantic import Extra, Field, root_validator
import numpy as np


class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses to a queue.

    Every new token reported via ``on_llm_new_token`` is pushed onto ``q``
    with ``q.put(token)``.
    """

    def __init__(self, q):
        # NOTE(review): ``q`` only needs a ``put`` method (presumably a
        # ``queue.Queue``); confirm against callers.
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a newly generated token to the queue."""
        # NOTE(review): synchronous ``put`` on an async handler -- assumes the
        # queue does not block (e.g. unbounded); confirm.
        self.q.put(token)


class SyncStreamingLLMCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming LLM responses to a queue.

    Only ``on_llm_new_token`` has an effect (it pushes the token onto ``q``);
    every other callback is an explicit no-op.
    """

    def __init__(self, q):
        # NOTE(review): ``q`` only needs a ``put`` method; confirm against callers.
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a newly generated token to the queue."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        pass

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        pass


def concatenate_generators(*args):
    """Stream the concatenation of several text sources as cumulative text.

    Each argument is either a generator or a plain string:

    * A generator is treated as yielding *cumulative* text. Each value it
      yields is emitted prefixed with everything produced by the preceding
      arguments. Only its last yielded value is kept as its contribution to
      that prefix.
    * Any other value is emitted once, prefixed the same way, and then
      appended to the prefix. It must support ``str + value``.

    A generator that yields nothing contributes nothing.

    Yields:
        The accumulated text so far, one value per item produced.
    """
    final_outputs = ""
    for g in args:
        if isinstance(g, GeneratorType):
            # Reset per source: otherwise an empty generator would re-append
            # the previous source's text, or raise UnboundLocalError when it
            # is the first argument.
            result = ""
            for v in g:
                yield final_outputs + v
                result = v
        else:
            yield final_outputs + g
            result = g
        final_outputs += result


class CustomOpenAIEmbeddings(OpenAIEmbeddings):
    """A version of OpenAIEmbeddings that allows extra args to be passed to OpenAI functions.

    Based on langchain's ChatOpenAI. Any constructor keyword that is not a
    declared field is moved into ``model_kwargs``. ``model_kwargs`` is then
    forwarded as keyword arguments to every ``embed_with_retry`` call made
    by this class.
    """

    # Extra keyword arguments forwarded verbatim to ``embed_with_retry``.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Unknown parameters are moved from ``values`` into ``model_kwargs``.

        Raises:
            ValueError: if a parameter is given both directly and inside
                ``model_kwargs``, or if ``model_kwargs`` contains a declared
                field name (or ``"model"``).
        """
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        # Copy so the caller's own ``model_kwargs`` dict is not mutated below.
        extra = dict(values.get("model_kwargs") or {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                # Not a declared field: silently transfer it to model_kwargs.
                extra[field_name] = values.pop(field_name)

        disallowed_model_kwargs = all_required_field_names | {"model"}
        invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    # Same as the parent implementation, but forwards ``model_kwargs`` in calls.
    # See https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
    def _get_len_safe_embeddings(
        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
    ) -> List[List[float]]:
        """Embed ``texts``, splitting any text longer than the context window.

        The steps are:

        1. Tokenize each text with ``tiktoken``.
        2. Split the tokens into chunks of at most
           ``self.embedding_ctx_length`` tokens.
        3. Send the chunks to ``embed_with_retry`` in groups of
           ``chunk_size``.
        4. For each text, average its chunk embeddings (weighted by chunk
           token count) and L2-normalize the result.

        Args:
            texts: Texts to embed.
            engine: Engine/deployment passed to ``embed_with_retry``.
            chunk_size: Max number of token chunks per request. Defaults to
                ``self.chunk_size``.

        Returns:
            One embedding per input text, in input order.

        Raises:
            ValueError: if the ``tiktoken`` package is not installed.
        """
        # Keep the try narrow so only a missing tiktoken is reported this way.
        try:
            import tiktoken
        except ImportError as e:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to for OpenAIEmbeddings. "
                "Please install it with `pip install tiktoken`."
            ) from e

        tokens: List[List[int]] = []
        indices: List[int] = []  # indices[k] = which text tokens[k] came from
        encoding = tiktoken.model.encoding_for_model(self.model)
        for i, text in enumerate(texts):
            if self.model.endswith("001"):
                # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
                # replace newlines, which can negatively affect performance.
                text = text.replace("\n", " ")
            token = encoding.encode(
                text,
                allowed_special=self.allowed_special,
                disallowed_special=self.disallowed_special,
            )
            for j in range(0, len(token), self.embedding_ctx_length):
                tokens.append(token[j : j + self.embedding_ctx_length])
                indices.append(i)

        batched_embeddings: List[List[float]] = []
        _chunk_size = chunk_size or self.chunk_size
        for i in range(0, len(tokens), _chunk_size):
            response = embed_with_retry(
                self,
                input=tokens[i : i + _chunk_size],
                engine=engine,
                request_timeout=self.request_timeout,
                **self.model_kwargs,
            )
            batched_embeddings.extend(r["embedding"] for r in response["data"])

        # Group chunk embeddings (and their token counts) back by source text.
        # Indexing (rather than zip) keeps an IndexError if the API returned
        # fewer embeddings than chunks sent.
        results: List[List[List[float]]] = [[] for _ in texts]
        num_tokens_in_batch: List[List[int]] = [[] for _ in texts]
        for k, text_idx in enumerate(indices):
            results[text_idx].append(batched_embeddings[k])
            num_tokens_in_batch[text_idx].append(len(tokens[k]))

        embeddings: List[List[float]] = []
        for _result, weights in zip(results, num_tokens_in_batch):
            if not _result:
                # The text produced no tokens (e.g. ""): embed the empty string.
                average = embed_with_retry(
                    self,
                    input="",
                    engine=engine,
                    request_timeout=self.request_timeout,
                    **self.model_kwargs,
                )["data"][0]["embedding"]
            else:
                average = np.average(_result, axis=0, weights=weights)
            embeddings.append((average / np.linalg.norm(average)).tolist())
        return embeddings

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint, forwarding ``model_kwargs``.

        Texts longer than ``self.embedding_ctx_length`` are routed through
        ``_get_len_safe_embeddings``. Shorter texts are embedded in a single
        request.
        """
        # NOTE(review): this compares a *character* count with a *token* limit.
        # Non-ASCII text can exceed the token limit while passing this check.
        # Confirm this matches the intended contract.
        if len(text) > self.embedding_ctx_length:
            return self._get_len_safe_embeddings([text], engine=engine)[0]

        if self.model.endswith("001"):
            # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
            # replace newlines, which can negatively affect performance.
            text = text.replace("\n", " ")
        return embed_with_retry(
            self,
            input=[text],
            engine=engine,
            request_timeout=self.request_timeout,
            **self.model_kwargs,
        )["data"][0]["embedding"]