File size: 1,389 Bytes
44368cf
f18e889
584b6a4
c7f1a46
 
aa65bc3
f18e889
44368cf
 
0972315
f18e889
 
 
 
7d60587
f18e889
 
44368cf
 
 
 
f18e889
44368cf
 
7ac68ef
f18e889
 
c7f1a46
 
 
 
7ac68ef
fcd4555
71df5b2
 
7ac68ef
584b6a4
 
7ac68ef
f18e889
921dd6f
 
 
 
 
 
f18e889
 
 
 
0972315
f18e889
 
 
584b6a4
f18e889
 
 
44368cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import atexit
from io import BytesIO
from multiprocessing.connection import Listener
from os import chmod, remove
from os.path import abspath, exists
from pathlib import Path

import torch

from PIL.JpegImagePlugin import JpegImageFile
from pipelines.models import TextToImageRequest

from pipeline import load_pipeline, infer

# Absolute path of the socket file clients connect to: "inferences.sock"
# placed in the parent of the directory that contains this file.
SOCKET = abspath(Path(__file__).parent.parent.joinpath("inferences.sock"))


def at_exit():
    """Exit hook: ask PyTorch's CUDA caching allocator to release unused cached memory."""
    cuda = torch.cuda
    cuda.empty_cache()


def _serve(connection, pipeline):
    """Answer requests on *connection* until the peer closes it.

    Each incoming message is UTF-8 JSON parsed as a ``TextToImageRequest``;
    each reply is the image produced by ``infer`` saved in JPEG format.
    """
    while True:
        # Only the receive can signal a disconnect; keep the try minimal so
        # validation/inference errors are not confused with EOF.
        try:
            payload = connection.recv_bytes()
        except EOFError:
            print("Inference socket exiting")
            return

        request = TextToImageRequest.model_validate_json(payload.decode("utf-8"))
        image = infer(request, pipeline)

        buffer = BytesIO()
        image.save(buffer, format=JpegImageFile.format)
        connection.send_bytes(buffer.getvalue())


def main():
    """Load the pipeline, then serve a single client on the socket at SOCKET.

    Registers ``at_exit`` as an interpreter exit hook. Returns once the
    accepted client disconnects.
    """
    atexit.register(at_exit)

    print("Loading pipeline")
    pipeline = load_pipeline()

    print(f"Pipeline loaded, creating socket at '{SOCKET}'")

    # A socket file left behind by a previous run would make binding fail.
    # Remove it unconditionally instead of check-then-remove (racy, and it
    # misses dangling symlinks that os.path.exists reports as absent).
    try:
        remove(SOCKET)
    except FileNotFoundError:
        pass

    with Listener(SOCKET) as listener:
        # NOTE(review): 0o777 lets any local user connect and submit work;
        # tighten if the client always runs as the same user/group.
        chmod(SOCKET, 0o777)

        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")
            _serve(connection, pipeline)


# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    main()