"""Text-to-image inference server over a local socket.

Loads the pipeline once, then accepts a single client connection on
``SOCKET`` and answers each JSON-encoded ``TextToImageRequest`` with the
generated image encoded as JPEG bytes.
"""

import atexit
from io import BytesIO
from multiprocessing.connection import Listener
from os import chmod, remove
from os.path import abspath, exists
from pathlib import Path

from git import Repo
import torch
from PIL.JpegImagePlugin import JpegImageFile
from pipelines.models import TextToImageRequest

from pipeline import load_pipeline, infer

SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")

# File whose metadata gates pipeline loading; resolved relative to the
# current working directory.
_PARAMS_FILE = "loss_params.pth"


def at_exit():
    """Release cached CUDA memory when the interpreter exits."""
    torch.cuda.empty_cache()


def main():
    """Load the pipeline and serve inference requests over ``SOCKET``.

    Only one client connection is ever accepted. Each message received is
    decoded as UTF-8 JSON into a ``TextToImageRequest``, passed to ``infer``,
    and the resulting image is sent back as JPEG bytes. The function returns
    once the client closes its end (``recv_bytes`` raises ``EOFError``).
    """
    atexit.register(at_exit)

    print("Loading pipeline")
    pipeline = _load_pipeline()
    if pipeline is None:
        # The server still starts, but infer() receives None as the pipeline;
        # NOTE(review): confirm whether infer() can cope with that.
        print("WARNING: pipeline is None; inference requests will likely fail")

    print(f"Pipeline loaded, creating socket at '{SOCKET}'")

    # Remove a leftover socket file from a previous run so the Listener can
    # bind to the same path.
    if exists(SOCKET):
        remove(SOCKET)

    with Listener(SOCKET) as listener:
        # NOTE(review): 0o777 lets any local user connect to the socket;
        # tighten if cross-user access is not required.
        chmod(SOCKET, 0o777)

        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")

            while True:
                # Only recv_bytes() signals the peer hanging up, so keep the
                # try narrow; validation errors must not be mistaken for EOF.
                try:
                    payload = connection.recv_bytes()
                except EOFError:
                    print("Inference socket exiting")
                    return

                request = TextToImageRequest.model_validate_json(payload.decode("utf-8"))
                image = infer(request, pipeline)

                buffer = BytesIO()
                image.save(buffer, format=JpegImageFile.format)
                connection.send_bytes(buffer.getvalue())


def _load_pipeline():
    """Load the pipeline only if this checkout belongs to the recorded author.

    Reads ``metadata["author"]`` from ``_PARAMS_FILE`` and requires it to be a
    substring of the ``origin`` remote URL of the current git repository.
    The (expensive) pipeline is loaded only after that check passes.

    Returns:
        The loaded pipeline, or ``None`` if the metadata cannot be read, the
        remote URL does not contain the author, or loading fails. The reason
        is printed in each case.
    """
    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load trusted files. Consider weights_only=True
        # if the file holds plain data only and the torch version supports it.
        loaded_data = torch.load(_PARAMS_FILE)
        author = loaded_data["metadata"]["author"]
    except Exception as e:
        print(f"Could not read author metadata from '{_PARAMS_FILE}': {e}")
        return None

    remote_url = get_git_remote_url()
    if not isinstance(author, str) or remote_url is None or author not in remote_url:
        print("Author check failed: git remote does not match recorded author")
        return None

    try:
        return load_pipeline()
    except Exception as e:
        print(f"Failed to load pipeline: {e}")
        return None


def get_git_remote_url():
    """Return the URL of the ``origin`` remote of the repository in ``.``.

    Returns ``None`` (after printing the error) if anything goes wrong, e.g.
    the current directory is not a git repository or has no ``origin``.
    """
    try:
        return Repo(".").remotes.origin.url
    except Exception as e:
        print(f"Error: {e}")
        return None


if __name__ == '__main__':
    main()