"""Callback handlers that allow listening to events in LangChain."""
import os
from contextlib import contextmanager
from typing import Generator, Optional

from langchain.callbacks.base import (
    BaseCallbackHandler,
    BaseCallbackManager,
    CallbackManager,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers import SharedLangChainTracer
from langchain.callbacks.wandb_callback import WandbCallbackHandler


def get_callback_manager() -> BaseCallbackManager:
    """Return the shared callback manager."""
    return SharedCallbackManager()
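

# Illustrative sketch, not LangChain's implementation: get_callback_manager()
# is assumed to hand back one process-wide instance (the SharedCallbackManager),
# so handlers added through set_handler() or get_openai_callback() are visible
# to every caller that fetches it. The hypothetical class below shows that
# thread-safe singleton pattern in isolation.
import threading


class _SketchSharedManager:
    """Hypothetical stand-in: every instantiation returns the same object."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls) -> "_SketchSharedManager":
        # Double-checked locking: the lock is only taken while no instance
        # exists, and at most one instance is created even under concurrency.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance


# Usage: `_SketchSharedManager() is _SketchSharedManager()` evaluates to True.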


def set_handler(handler: BaseCallbackHandler) -> None:
    """Make ``handler`` the only handler on the shared callback manager."""
    callback = get_callback_manager()
    callback.set_handler(handler)


def set_default_callback_manager() -> None:
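    """Configure the shared callback manager from ``LANGCHAIN_HANDLER``.

    ``stdout`` (the default when the variable is unset) installs a
    StdOutCallbackHandler; ``langchain`` enables tracing through
    set_tracing_callback_manager, using the session named by
    ``LANGCHAIN_SESSION`` if it is set. Any other value raises ValueError.
    """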
    default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout")
    if default_handler == "stdout":
        set_handler(StdOutCallbackHandler())
    elif default_handler == "langchain":
        session = os.environ.get("LANGCHAIN_SESSION")
        set_tracing_callback_manager(session)
    else:
        raise ValueError(
            "LANGCHAIN_HANDLER should be one of `stdout` "
            f"or `langchain`, got {default_handler}"
        )
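

# Illustrative sketch of the dispatch in set_default_callback_manager(), with
# LangChain stripped out: the hypothetical helper maps a LANGCHAIN_HANDLER
# value to the names of the handlers that would end up installed. Taking
# ``environ`` as a parameter (instead of reading os.environ) is an assumption
# made so the branching is easy to test in isolation.
from typing import Dict, List, Mapping


def _sketch_default_handler_names(environ: Mapping[str, str]) -> List[str]:
    """Hypothetical: handler names installed for a given environment."""
    choices: Dict[str, List[str]] = {
        # "stdout" (also the default) installs only the stdout printer.
        "stdout": ["stdout"],
        # "langchain" installs the tracer plus stdout, as
        # set_tracing_callback_manager() does.
        "langchain": ["langchain_tracer", "stdout"],
    }
    value = environ.get("LANGCHAIN_HANDLER", "stdout")
    if value not in choices:
        raise ValueError(
            "LANGCHAIN_HANDLER should be one of `stdout` "
            f"or `langchain`, got {value}"
        )
    return choices[value]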


def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
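    """Route callbacks to the shared LangChain tracer and to stdout.

    Replaces the shared manager's handlers, then loads ``session_name`` if
    given, or the tracer's default session otherwise. Raises ValueError if
    the named session cannot be loaded.
    """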
    handler = SharedLangChainTracer()
    callback = get_callback_manager()
    callback.set_handlers([handler, StdOutCallbackHandler()])
    if session_name is None:
        handler.load_default_session()
    else:
        try:
            handler.load_session(session_name)
        except Exception as e:
            # Chain the original error so the real cause stays visible.
            raise ValueError(f"session {session_name} not found") from e


@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
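    """Yield an OpenAICallbackHandler registered on the shared callback manager.

    The handler is removed when the ``with`` block exits, even on error.
    """
    handler = OpenAICallbackHandler()
    manager = get_callback_manager()
    manager.add_handler(handler)
    try:
        yield handler
    finally:
        # Unregister even if the body raises, so the handler does not leak.
        manager.remove_handler(handler)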
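

# Illustrative sketch, independent of LangChain: the register / yield /
# unregister pattern behind get_openai_callback(), using the hypothetical
# names _SketchManager and _sketch_scoped_handler. The try/finally guarantees
# the handler is removed even if the ``with`` body raises.
import contextlib
from typing import Any, Iterator, List


class _SketchManager:
    """Hypothetical minimal manager that only tracks registered handlers."""

    def __init__(self) -> None:
        self.handlers: List[Any] = []

    def add_handler(self, handler: Any) -> None:
        self.handlers.append(handler)

    def remove_handler(self, handler: Any) -> None:
        self.handlers.remove(handler)


@contextlib.contextmanager
def _sketch_scoped_handler(manager: _SketchManager, handler: Any) -> Iterator[Any]:
    """Register ``handler`` on ``manager`` for the duration of a ``with`` block."""
    manager.add_handler(handler)
    try:
        yield handler
    finally:
        manager.remove_handler(handler)


# Usage:
#     m = _SketchManager()
#     with _sketch_scoped_handler(m, "h"):
#         ...  # m.handlers == ["h"] here
#     # m.handlers == [] again, whether or not the body raised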


__all__ = [
    "CallbackManager",
    "OpenAICallbackHandler",
    "SharedCallbackManager",
    "StdOutCallbackHandler",
    "WandbCallbackHandler",
    "get_openai_callback",
    "set_tracing_callback_manager",
    "set_default_callback_manager",
    "set_handler",
    "get_callback_manager",
]