File size: 3,995 Bytes
d1ceb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import os.path
import socket
import socketserver
import threading
from contextlib import closing, contextmanager
from http.server import SimpleHTTPRequestHandler
from typing import Callable, Generator
from urllib.request import urlopen

import h11


@contextmanager
def socket_server(
    handler: Callable[..., socketserver.BaseRequestHandler]
) -> Generator[socketserver.TCPServer, None, None]:
    """Run a throwaway TCP server on localhost for the duration of a ``with``.

    The server binds to ``127.0.0.1`` on an OS-assigned port (port ``0``) and
    is served from a daemon background thread.

    Args:
        handler: Request handler class (or factory) passed to
            ``socketserver.TCPServer``; it is invoked once per connection.

    Yields:
        The running ``TCPServer``; read ``server_address`` to find the port.

    On exit the serve loop is stopped, the serving thread is joined, and the
    listening socket is closed so nothing leaks between tests.
    """
    httpd = socketserver.TCPServer(("127.0.0.1", 0), handler)
    try:
        thread = threading.Thread(
            target=httpd.serve_forever, kwargs={"poll_interval": 0.01}
        )
        # Daemon thread: a stuck server must never keep the interpreter alive.
        thread.daemon = True
        thread.start()
        try:
            yield httpd
        finally:
            # shutdown() only stops serve_forever(); it is called solely once
            # the thread has been started, since it waits on the serve loop.
            httpd.shutdown()
            thread.join()
    finally:
        # shutdown() does not release the listening socket; close it
        # explicitly, including when starting the thread failed.
        httpd.server_close()


# Locate the fixture relative to this module and load its bytes once at import
# time; the client test compares the served response body against them.
_this_dir = os.path.dirname(__file__)
test_file_path = os.path.join(_this_dir, "data/test-file")
with open(test_file_path, mode="rb") as f:
    test_file_data = f.read()


class SingleMindedRequestHandler(SimpleHTTPRequestHandler):
    """HTTP handler that maps every requested path to the one test file."""

    def translate_path(self, path: str) -> str:
        """Return ``test_file_path`` regardless of the requested ``path``."""
        # The incoming URL path is deliberately ignored.
        return test_file_path


def test_h11_as_client() -> None:
    """Drive h11 as an HTTP client against the stdlib file-serving handler.

    Sends a ``GET /foo`` over a raw socket, feeds the reply to h11 in small
    pieces, and checks that the status is 200 and the body equals the
    fixture file's bytes.
    """
    with socket_server(SingleMindedRequestHandler) as httpd:
        with closing(socket.create_connection(httpd.server_address)) as sock:
            conn = h11.Connection(h11.CLIENT)

            # Serialize and send the request head, then mark the message end.
            get_request = h11.Request(
                method="GET", target="/foo", headers=[("Host", "localhost")]
            )
            sock.sendall(conn.send(get_request))  # type: ignore[arg-type]
            sock.sendall(conn.send(h11.EndOfMessage()))  # type: ignore[arg-type]

            body = bytearray()
            while True:
                event = conn.next_event()
                print(event)
                if event is h11.NEED_DATA:
                    # Read only 10 bytes at a time so that h11 has to
                    # reassemble the response from many small fragments.
                    conn.receive_data(sock.recv(10))
                elif type(event) is h11.Response:
                    assert event.status_code == 200
                elif type(event) is h11.Data:
                    body.extend(event.data)
                elif type(event) is h11.EndOfMessage:
                    break
            assert bytes(body) == test_file_data


class H11RequestHandler(socketserver.BaseRequestHandler):
    """Socket handler that uses h11 as the server side of one HTTP exchange.

    It reads a single request and answers with a 200 response whose body is a
    JSON object describing the request's method, target and headers.
    """

    def handle(self) -> None:
        """Parse one request from the socket and echo its details as JSON."""
        with closing(self.request) as sock:
            conn = h11.Connection(h11.SERVER)
            request = None

            # Pump events until the request message is complete.
            while True:
                event = conn.next_event()
                if event is h11.NEED_DATA:
                    # Read only 10 bytes at a time so that h11 has to
                    # reassemble the request from many small fragments.
                    conn.receive_data(sock.recv(10))
                elif type(event) is h11.Request:
                    request = event
                elif type(event) is h11.EndOfMessage:
                    break
            assert request is not None

            # Method, target and header names/values are bytes; decode them so
            # they can be serialized as JSON.
            header_map = {}
            for name, value in request.headers:
                header_map[name.decode("ascii")] = value.decode("ascii")
            summary = {
                "method": request.method.decode("ascii"),
                "target": request.target.decode("ascii"),
                "headers": header_map,
            }
            payload = json.dumps(summary).encode("ascii")

            sock.sendall(conn.send(h11.Response(status_code=200, headers=[])))  # type: ignore[arg-type]
            sock.sendall(conn.send(h11.Data(data=payload)))
            sock.sendall(conn.send(h11.EndOfMessage()))


def test_h11_as_server() -> None:
    """Hit the h11-backed handler with ``urlopen`` and verify what it echoes.

    The handler replies with a JSON description of the request; this checks
    the method, the target path, and that the user agent names urllib.
    """
    with socket_server(H11RequestHandler) as httpd:
        host, port = httpd.server_address
        with closing(urlopen(f"http://{host}:{port}/some-path")) as response:
            assert response.getcode() == 200
            raw_body = response.read()

    # The body was read in full above, so it is parsed after the server stops.
    info = json.loads(raw_body.decode("ascii"))
    print(info)
    assert info["method"] == "GET"
    assert info["target"] == "/some-path"
    assert "urllib" in info["headers"]["user-agent"]