|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Example of invisibly propagating a request ID with middleware.""" |
|
|
|
import argparse |
|
import sys |
|
import threading |
|
import uuid |
|
|
|
import pyarrow as pa |
|
import pyarrow.flight as flight |
|
|
|
|
|
class TraceContext:
    """Hold the trace ID associated with the calling thread.

    The ID is stored on a ``threading.local`` object, so each thread only
    sees the value that it set (or lazily generated) itself.
    """

    _locals = threading.local()
    # This only initialises the slot for the thread that executes the class
    # body; other threads start without the attribute, which
    # ``current_trace_id`` tolerates by reading it through ``getattr``.
    _locals.trace_id = None

    @classmethod
    def set_trace_id(cls, trace_id):
        """Record ``trace_id`` as the calling thread's trace ID."""
        setattr(cls._locals, "trace_id", trace_id)

    @classmethod
    def current_trace_id(cls):
        """Return the calling thread's trace ID, generating one if needed.

        A missing or falsy ID (``None`` or ``""``) is replaced by a random
        UUID4 hex string, which is stored so that later calls on the same
        thread return the same value.
        """
        existing = getattr(cls._locals, "trace_id", None)
        if existing:
            return existing
        cls.set_trace_id(uuid.uuid4().hex)
        return cls._locals.trace_id
|
|
|
|
|
# Name of the header under which the trace ID is sent and looked up by the
# tracing middleware in this file.
TRACE_HEADER = "x-tracing-id"
|
|
|
|
|
class TracingServerMiddleware(flight.ServerMiddleware):
    """Server middleware instance that reports a call's trace ID in headers."""

    def __init__(self, trace_id):
        """Remember the trace ID chosen for this call.

        ``trace_id`` stays a public attribute so that code holding the
        middleware instance can read it back.
        """
        self.trace_id = trace_id

    def sending_headers(self):
        """Return the headers to send: the trace ID under ``TRACE_HEADER``."""
        outgoing = {}
        outgoing[TRACE_HEADER] = self.trace_id
        return outgoing
|
|
|
|
|
class TracingServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Create a ``TracingServerMiddleware`` for each incoming call."""

    def start_call(self, info, headers):
        """Pick the trace ID for a new call and build its middleware.

        Parameters
        ----------
        info
            Call information supplied by Flight; only printed here.
        headers
            Mapping of header name to a sequence of values; the first value
            stored under ``TRACE_HEADER`` is used when present.

        Returns
        -------
        TracingServerMiddleware
            Middleware carrying the trace ID selected for this call.
        """
        print("Starting new call:", info)
        trace_id = None
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
        if not trace_id:
            # Always mint a fresh ID when the caller supplied none (or an
            # empty one). Falling back to TraceContext.current_trace_id()
            # would return whatever ID an earlier call left behind in this
            # thread's thread-local storage, leaking it into an unrelated
            # request.
            trace_id = uuid.uuid4().hex
        TraceContext.set_trace_id(trace_id)
        return TracingServerMiddleware(trace_id)
|
|
|
|
|
class TracingClientMiddleware(flight.ClientMiddleware):
    """Client middleware that sends the current trace ID with each call."""

    def sending_headers(self):
        """Return the request headers: the current trace ID under ``TRACE_HEADER``.

        The ID is read once from ``TraceContext`` so that the printed value
        and the value sent are guaranteed to be the same.
        """
        trace_id = TraceContext.current_trace_id()
        print("Sending trace ID:", trace_id)
        # Use the shared constant rather than repeating the literal so the
        # header name cannot drift from the one the other middleware reads.
        return {
            TRACE_HEADER: trace_id,
        }

    def received_headers(self, headers):
        """Print the trace ID found in the received headers, if any.

        ``headers`` maps a header name to a sequence of values; only the
        first value stored under ``TRACE_HEADER`` is reported.
        """
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
|
|
|
|
|
|
|
class TracingClientMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Create a ``TracingClientMiddleware`` for each outgoing call."""

    def start_call(self, info):
        """Print the call information and return a new client middleware."""
        print("Starting new call:", info)
        middleware = TracingClientMiddleware()
        return middleware
|
|
|
|
|
class FlightServer(flight.FlightServerBase):
    """Flight server exposing a single ``get-trace-id`` action.

    The action either answers with this server's current trace ID or, when a
    delegate location was configured, forwards the action to that server and
    relays its results.
    """

    def __init__(self, delegate, **kwargs):
        """Initialise the server and, if requested, connect to a delegate.

        Parameters
        ----------
        delegate
            Location of another Flight server to forward actions to; a falsy
            value disables delegation.
        **kwargs
            Passed through unchanged to ``flight.FlightServerBase``.
        """
        super().__init__(**kwargs)
        self.delegate = (
            flight.connect(
                delegate,
                middleware=(TracingClientMiddlewareFactory(),),
            )
            if delegate
            else None
        )

    def list_actions(self, context):
        """Return the (name, description) pairs of the supported actions."""
        return [("get-trace-id", "Get the trace context ID.")]

    def do_action(self, context, action):
        """Yield the results of ``action``.

        This is a generator, so nothing below runs until the results are
        iterated.

        Raises
        ------
        KeyError
            If ``action.type`` is not ``"get-trace-id"``.
        """
        # Re-apply the trace ID recorded by the "trace" middleware for this
        # call. NOTE(review): presumably needed because this handler may not
        # run on the thread that created the middleware - confirm.
        middleware = context.get_middleware("trace")
        if middleware:
            TraceContext.set_trace_id(middleware.trace_id)

        if action.type != "get-trace-id":
            raise KeyError(f"Unknown action {action.type!r}")

        if self.delegate:
            # Relay whatever the delegate server answers.
            yield from self.delegate.do_action(action)
            return

        encoded_id = TraceContext.current_trace_id().encode("utf-8")
        print("Returning trace ID:", encoded_id)
        yield flight.Result(pa.py_buffer(encoded_id))
|
|
|
|
|
def main():
    """Parse the command line and run either the server or the client.

    Returns
    -------
    int or None
        ``1`` when no sub-command was given (after printing the help text);
        otherwise ``None`` once the chosen sub-command finishes.
    """
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="command")

    # "client" sub-command.
    client_parser = commands.add_parser("client", help="Run the client.")
    client_parser.add_argument("server")
    client_parser.add_argument("--request-id", default=None)

    # "server" sub-command.
    server_parser = commands.add_parser("server", help="Run the server.")
    server_parser.add_argument(
        "--listen",
        required=True,
        help="The location to listen on (example: grpc://localhost:5050)",
    )
    server_parser.add_argument(
        "--delegate",
        required=False,
        default=None,
        help=(
            "A location to delegate to. That is, this server will "
            "simply call the given server for the response. Demonstrates "
            "propagation of the trace ID between servers."
        ),
    )

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        return 1

    if args.command == "server":
        flight_server = FlightServer(
            args.delegate,
            location=args.listen,
            middleware={"trace": TracingServerMiddlewareFactory()},
        )
        flight_server.serve()
    elif args.command == "client":
        flight_client = flight.connect(
            args.server,
            middleware=(TracingClientMiddlewareFactory(),),
        )
        # Use the ID given on the command line, or a fixed one otherwise.
        TraceContext.set_trace_id(args.request_id or "client-chosen-id")

        request = flight.Action("get-trace-id", b"")
        for result in flight_client.do_action(request):
            print(result.body.to_pybytes())
|
|
|
|
|
if __name__ == "__main__":
    # main() returns None on success, so map any falsy result to exit code 0.
    exit_code = main() or 0
    sys.exit(exit_code)
|
|