File size: 3,708 Bytes
3943768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import asyncio
import multiprocessing
import queue
import threading
import traceback
import typing
from contextlib import contextmanager

from autogen.io import IOStream, OutputStream

from openai_server.agent_utils import filter_kwargs


class CustomOutputStream(OutputStream):
    """Output stream that blanks ANSI colour markers before printing.

    ``print`` replaces the ANSI "green" (ESC[32m) and "reset" (ESC[0m)
    sequences with empty strings and delegates to the parent ``print``;
    ``dump`` hands its positional arguments straight back to the caller.
    """

    def print(self, *objects, sep="", end="", flush=False):
        """Forward *objects* to the parent ``print`` with colour markers blanked.

        The caller's ``sep`` and ``end`` are not forwarded: empty strings are
        always passed to the parent instead. ``flush`` is passed through.
        """
        colour_markers = ("\033[32m", "\033[0m")
        cleaned = ['' if item in colour_markers else item for item in objects]
        super().print(*cleaned, sep="", end="", flush=flush)

    def dump(self, *objects, sep="", end="", flush=False):
        """Return the positional arguments as a tuple instead of printing them."""
        return objects


class CustomIOStream(IOStream, CustomOutputStream):
    """Combine ``IOStream`` with ``CustomOutputStream``; adds no members of its own."""


class CaptureIOStream(IOStream):
    """``IOStream`` whose ``print`` pushes the rendered text onto a queue."""

    def __init__(self, output_queue: queue.Queue):
        """Store the queue that receives one item per ``print`` call."""
        self.output_queue = output_queue

    def print(self, *objects: typing.Any, sep: str = "", end: str = "", flush: bool = True) -> None:
        """Render *objects* to a single string and enqueue it.

        Arguments equal to the ANSI "green" marker (ESC[32m) or to the
        "reset" marker followed by a newline (ESC[0m + LF) are replaced by
        empty strings; everything else is converted with ``str``, joined with
        ``sep`` and suffixed with ``end``. ``flush`` is accepted but unused.
        """
        dropped = ("\033[32m", "\033[0m\n")
        kept = ['' if obj in dropped else obj for obj in objects]
        rendered = sep.join(str(obj) for obj in kept)
        self.output_queue.put(rendered + end)


@contextmanager
def capture_iostream(output_queue: queue.Queue) -> typing.Generator[CaptureIOStream, None, None]:
    """Context manager that routes IOStream output into *output_queue*.

    A ``CaptureIOStream`` bound to *output_queue* is handed to
    ``IOStream.set_default`` (used as a context manager) and yielded to the
    caller for the duration of the ``with`` block.
    """
    stream = CaptureIOStream(output_queue)
    with IOStream.set_default(stream):
        yield stream


def run_autogen_in_proc(func, output_queue, result_queue, exception_queue, **kwargs):
    """Worker entry point: run *func* with its IOStream output captured.

    Args:
        func: Callable invoked as ``func(**kwargs)``.
        output_queue: Receives every chunk printed through the captured
            IOStream, followed by a final ``None`` end-of-output sentinel.
        result_queue: Receives exactly one item: the return value of *func*,
            or ``{}`` if *func* did not return normally.
        exception_queue: Receives the exception object if *func* raised.
        **kwargs: Keyword arguments forwarded to *func*.

    Nothing is raised to the caller; failures are reported via
    *exception_queue* (the traceback is also printed).
    """
    ret_dict = {}
    try:
        with capture_iostream(output_queue):
            ret_dict = func(**kwargs)
    except BaseException as e:
        # BaseException on purpose: SystemExit / KeyboardInterrupt raised by the
        # agent must also be reported to the consumer rather than silently
        # ending the worker.
        print(traceback.format_exc())
        exception_queue.put(e)
    finally:
        # Publish the result once. It was previously enqueued twice on
        # success, leaving a stale duplicate (and, for multiprocessing queues,
        # twice the data to flush before the child can exit).
        try:
            result_queue.put(ret_dict)
        finally:
            # Always signal end of output so the consumer loop terminates.
            output_queue.put(None)


_QUEUE_EMPTY = object()  # marker: "nothing was available on the queue"


def _poll_queue(q) -> typing.Any:
    """Return the next item of *q* without blocking, or ``_QUEUE_EMPTY`` if none is ready."""
    try:
        return q.get_nowait()
    except queue.Empty:
        return _QUEUE_EMPTY


def _sanitize_agent_exception(exc: BaseException) -> BaseException:
    """Map a worker ``SystemExit`` to ``ValueError("SystemExit")``; pass others through."""
    if isinstance(exc, SystemExit):
        return ValueError("SystemExit")
    return exc


async def iostream_generator(func, use_process=False, **kwargs) -> typing.AsyncGenerator[typing.Any, None]:
    """Run *func* in a worker thread or process and stream what it prints.

    Args:
        func: Callable executed in the worker via ``run_autogen_in_proc``.
        use_process: If true, run in a ``multiprocessing.Process`` with
            multiprocessing queues; otherwise in a ``threading.Thread`` with
            ``queue.Queue`` instances.
        **kwargs: Candidate keyword arguments for *func*; they are passed
            through ``filter_kwargs(func, kwargs)`` before being forwarded.

    Yields:
        Each captured output chunk as it becomes available, then, as the final
        item, the first value the worker placed on the result queue (``None``
        if it placed nothing).

    Raises:
        BaseException: Whatever exception the worker reported. A reported
            ``SystemExit`` is always re-raised as ``ValueError("SystemExit")``
            so it cannot terminate the hosting event loop.
    """
    # Install the colour-stripping stream as the global default.
    IOStream.set_global_default(CustomIOStream())

    if use_process:
        output_queue = multiprocessing.Queue()
        result_queue = multiprocessing.Queue()
        exception_queue = multiprocessing.Queue()
        proc_cls = multiprocessing.Process
    else:
        output_queue = queue.Queue()
        result_queue = queue.Queue()
        exception_queue = queue.Queue()
        proc_cls = threading.Thread

    # Filter kwargs against func's signature (per the original note: avoids
    # passing non-picklable extras to the worker).
    filtered_kwargs = filter_kwargs(func, kwargs)

    # Start the agent in a separate thread or process.
    agent_proc = proc_cls(target=run_autogen_in_proc,
                          args=(func, output_queue, result_queue, exception_queue),
                          kwargs=filtered_kwargs)
    agent_proc.start()

    # Yield output as it becomes available.
    while True:
        exc = _poll_queue(exception_queue)
        if exc is not _QUEUE_EMPTY:
            raise _sanitize_agent_exception(exc)

        output = _poll_queue(output_queue)
        if output is _QUEUE_EMPTY:
            # Nothing pending: back off briefly before polling again.
            await asyncio.sleep(0.005)
            continue
        if output is None:  # End of agent execution
            break
        yield output
        # Stay cooperative without throttling a burst of queued chunks.
        await asyncio.sleep(0)

    # Wait for the worker to finish while draining the result queue. Joining a
    # process before its queued data has been consumed can deadlock, and a
    # blocking join() would also stall the event loop.
    ret_dict = None
    have_result = False
    while True:
        # Sample liveness first so one more drain pass runs after the worker exits.
        alive = agent_proc.is_alive()
        while True:
            item = _poll_queue(result_queue)
            if item is _QUEUE_EMPTY:
                break
            if not have_result:  # keep the first item; later ones are duplicates
                ret_dict = item
                have_result = True
        if not alive:
            break
        await asyncio.sleep(0.005)

    agent_proc.join()

    # Hand back the final result.
    yield ret_dict

    # Surface an exception that arrived after the last in-loop check.
    exc = _poll_queue(exception_queue)
    if exc is not _QUEUE_EMPTY:
        raise _sanitize_agent_exception(exc)