Spaces:
Runtime error
Runtime error
File size: 2,511 Bytes
58d33f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
"""Callback handlers that allow listening to events in LangChain."""
import os
from contextlib import contextmanager
from typing import Generator, Optional
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers import SharedLangChainTracer
from langchain.callbacks.wandb_callback import WandbCallbackHandler
def get_callback_manager() -> BaseCallbackManager:
    """Get the shared callback manager used by the helpers in this module.

    Returns:
        A ``SharedCallbackManager`` instance (presumably a shared/singleton
        object, given its name -- confirm in ``langchain.callbacks.shared``).
    """
    shared_manager = SharedCallbackManager()
    return shared_manager
def set_handler(handler: BaseCallbackHandler) -> None:
    """Install ``handler`` on the shared callback manager.

    Args:
        handler: The callback handler forwarded to the manager's
            ``set_handler`` method.
    """
    get_callback_manager().set_handler(handler)
def set_default_callback_manager() -> None:
    """Configure the shared callback manager from environment variables.

    Reads ``LANGCHAIN_HANDLER`` (default ``"stdout"``):

    * ``"stdout"``: install a single ``StdOutCallbackHandler``.
    * ``"langchain"``: enable tracing through
      ``set_tracing_callback_manager``. It uses the session named by
      ``LANGCHAIN_SESSION``, or the default session when that variable
      is unset.

    Raises:
        ValueError: If ``LANGCHAIN_HANDLER`` has any other value.
    """
    default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout")
    if default_handler == "stdout":
        set_handler(StdOutCallbackHandler())
    elif default_handler == "langchain":
        # ``None`` when unset, which selects the tracer's default session.
        session = os.environ.get("LANGCHAIN_SESSION")
        set_tracing_callback_manager(session)
    else:
        raise ValueError(
            # Plain literal: this fragment has no placeholders (was F541).
            "LANGCHAIN_HANDLER should be one of `stdout` "
            f"or `langchain`, got {default_handler}"
        )
def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
    """Enable LangChain tracing on the shared callback manager.

    Replaces the shared manager's handlers with a ``SharedLangChainTracer``
    followed by a ``StdOutCallbackHandler``, then loads a tracing session
    on the tracer.

    Args:
        session_name: Name of the session to load. When ``None``, the
            tracer's default session is loaded instead.

    Raises:
        ValueError: If ``session_name`` is given and loading it raises.
            The original exception is chained as ``__cause__``.
    """
    handler = SharedLangChainTracer()
    callback = get_callback_manager()
    callback.set_handlers([handler, StdOutCallbackHandler()])
    if session_name is None:
        handler.load_default_session()
        return
    try:
        handler.load_session(session_name)
    except Exception as err:
        # NOTE(review): every failure is reported as "not found". Chaining
        # keeps the real cause (e.g. a network error) visible in tracebacks.
        raise ValueError(f"session {session_name} not found") from err
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
    """Temporarily attach an ``OpenAICallbackHandler`` to the shared manager.

    Yields:
        The handler, registered on the shared callback manager for the
        duration of the ``with`` block.

    The handler is removed when the block exits, including when the block
    raises. This keeps it from staying registered on the shared manager.
    """
    handler = OpenAICallbackHandler()
    manager = get_callback_manager()
    manager.add_handler(handler)
    try:
        yield handler
    finally:
        # Without ``finally``, an exception inside the ``with`` body would
        # skip this call and leave the handler attached permanently.
        manager.remove_handler(handler)
# Public names exported by ``from <this module> import *``. Note that some
# names imported above (``BaseCallbackHandler``, ``BaseCallbackManager``,
# ``SharedLangChainTracer``) are intentionally or accidentally not listed.
__all__ = [
"CallbackManager",
"OpenAICallbackHandler",
"SharedCallbackManager",
"StdOutCallbackHandler",
"WandbCallbackHandler",
"get_openai_callback",
"set_tracing_callback_manager",
"set_default_callback_manager",
"set_handler",
"get_callback_manager",
]
|