radames commited on
Commit
0bf34eb
·
2 Parent(s): 66b5681 ef88349

Merge branch 'main' into space-sdturbo

Browse files
app.py DELETED
@@ -1,17 +0,0 @@
1
- from fastapi import FastAPI
2
-
3
- from config import args
4
- from device import device, torch_dtype
5
- from app_init import init_app
6
- from user_queue import user_data
7
- from util import get_pipeline_class
8
-
9
- print("DEVICE:", device)
10
- print("TORCH_DTYPE:", torch_dtype)
11
- args.pretty_print()
12
-
13
- app = FastAPI()
14
-
15
- pipeline_class = get_pipeline_class(args.pipeline)
16
- pipeline = pipeline_class(args, device, torch_dtype)
17
- init_app(app, user_data, args, pipeline)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_init.py DELETED
@@ -1,163 +0,0 @@
1
- from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
2
- from fastapi.responses import StreamingResponse, JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from fastapi.staticfiles import StaticFiles
5
- from fastapi import Request
6
- import markdown2
7
-
8
- import logging
9
- import traceback
10
- from config import Args
11
- from user_queue import UserData
12
- import uuid
13
- import time
14
- from types import SimpleNamespace
15
- from util import pil_to_frame, bytes_to_pil, is_firefox
16
- import asyncio
17
- import os
18
- import time
19
-
20
- THROTTLE = 1.0 / 120
21
-
22
-
23
- def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
24
- app.add_middleware(
25
- CORSMiddleware,
26
- allow_origins=["*"],
27
- allow_credentials=True,
28
- allow_methods=["*"],
29
- allow_headers=["*"],
30
- )
31
-
32
- @app.websocket("/api/ws")
33
- async def websocket_endpoint(websocket: WebSocket):
34
- await websocket.accept()
35
- user_count = user_data.get_user_count()
36
- if args.max_queue_size > 0 and user_count >= args.max_queue_size:
37
- print("Server is full")
38
- await websocket.send_json({"status": "error", "message": "Server is full"})
39
- await websocket.close()
40
- return
41
- try:
42
- user_id = uuid.uuid4()
43
- print(f"New user connected: {user_id}")
44
-
45
- await user_data.create_user(user_id, websocket)
46
- await websocket.send_json(
47
- {"status": "connected", "message": "Connected", "userId": str(user_id)}
48
- )
49
- await websocket.send_json({"status": "send_frame"})
50
- await handle_websocket_data(user_id, websocket)
51
- except WebSocketDisconnect as e:
52
- logging.error(f"WebSocket Error: {e}, {user_id}")
53
- traceback.print_exc()
54
- finally:
55
- print(f"User disconnected: {user_id}")
56
- user_data.delete_user(user_id)
57
-
58
- async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
59
- if not user_data.check_user(user_id):
60
- return HTTPException(status_code=404, detail="User not found")
61
- last_time = time.time()
62
- try:
63
- while True:
64
- data = await websocket.receive_json()
65
- if data["status"] != "next_frame":
66
- asyncio.sleep(THROTTLE)
67
- continue
68
-
69
- params = await websocket.receive_json()
70
- params = pipeline.InputParams(**params)
71
- info = pipeline.Info()
72
- params = SimpleNamespace(**params.dict())
73
- if info.input_mode == "image":
74
- image_data = await websocket.receive_bytes()
75
- if len(image_data) == 0:
76
- await websocket.send_json({"status": "send_frame"})
77
- continue
78
- params.image = bytes_to_pil(image_data)
79
- await user_data.update_data(user_id, params)
80
- await websocket.send_json({"status": "wait"})
81
- if args.timeout > 0 and time.time() - last_time > args.timeout:
82
- await websocket.send_json(
83
- {
84
- "status": "timeout",
85
- "message": "Your session has ended",
86
- "userId": user_id,
87
- }
88
- )
89
- await websocket.close()
90
- return
91
- await asyncio.sleep(THROTTLE)
92
-
93
- except Exception as e:
94
- logging.error(f"Error: {e}")
95
- traceback.print_exc()
96
-
97
- @app.get("/api/queue")
98
- async def get_queue_size():
99
- queue_size = user_data.get_user_count()
100
- return JSONResponse({"queue_size": queue_size})
101
-
102
- @app.get("/api/stream/{user_id}")
103
- async def stream(user_id: uuid.UUID, request: Request):
104
- try:
105
- print(f"New stream request: {user_id}")
106
-
107
- async def generate():
108
- websocket = user_data.get_websocket(user_id)
109
- last_params = SimpleNamespace()
110
- while True:
111
- last_time = time.time()
112
- params = await user_data.get_latest_data(user_id)
113
- if not vars(params) or params.__dict__ == last_params.__dict__:
114
- await websocket.send_json({"status": "send_frame"})
115
- continue
116
-
117
- last_params = params
118
- image = pipeline.predict(params)
119
-
120
- if image is None:
121
- await websocket.send_json({"status": "send_frame"})
122
- continue
123
- frame = pil_to_frame(image)
124
- yield frame
125
- # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
126
- if not is_firefox(request.headers["user-agent"]):
127
- yield frame
128
- await websocket.send_json({"status": "send_frame"})
129
- if args.debug:
130
- print(f"Time taken: {time.time() - last_time}")
131
-
132
- return StreamingResponse(
133
- generate(),
134
- media_type="multipart/x-mixed-replace;boundary=frame",
135
- headers={"Cache-Control": "no-cache"},
136
- )
137
- except Exception as e:
138
- logging.error(f"Streaming Error: {e}, {user_id} ")
139
- traceback.print_exc()
140
- return HTTPException(status_code=404, detail="User not found")
141
-
142
- # route to setup frontend
143
- @app.get("/api/settings")
144
- async def settings():
145
- info_schema = pipeline.Info.schema()
146
- info = pipeline.Info()
147
- if info.page_content:
148
- page_content = markdown2.markdown(info.page_content)
149
-
150
- input_params = pipeline.InputParams.schema()
151
- return JSONResponse(
152
- {
153
- "info": info_schema,
154
- "input_params": input_params,
155
- "max_queue_size": args.max_queue_size,
156
- "page_content": page_content if info.page_content else "",
157
- }
158
- )
159
-
160
- if not os.path.exists("public"):
161
- os.makedirs("public")
162
-
163
- app.mount("/", StaticFiles(directory="public", html=True), name="public")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build-run.sh CHANGED
@@ -17,4 +17,4 @@ if [ -z ${COMPILE+x} ]; then
17
  fi
18
  echo -e "\033[1;32m\npipeline: $PIPELINE \033[0m"
19
  echo -e "\033[1;32m\ncompile: $COMPILE \033[0m"
20
- python3 run.py --port 7860 --host 0.0.0.0 --pipeline $PIPELINE $COMPILE
 
17
  fi
18
  echo -e "\033[1;32m\npipeline: $PIPELINE \033[0m"
19
  echo -e "\033[1;32m\ncompile: $COMPILE \033[0m"
20
+ python3 main.py --port 7860 --host 0.0.0.0 --pipeline $PIPELINE $COMPILE
config.py CHANGED
@@ -124,4 +124,5 @@ parser.add_argument(
124
  )
125
  parser.set_defaults(taesd=USE_TAESD)
126
 
127
- args = Args(**vars(parser.parse_args()))
 
 
124
  )
125
  parser.set_defaults(taesd=USE_TAESD)
126
 
127
+ config = Args(**vars(parser.parse_args()))
128
+ config.pretty_print()
connection_manager.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Union
2
+ from uuid import UUID
3
+ import asyncio
4
+ from fastapi import WebSocket
5
+ from starlette.websockets import WebSocketState
6
+ import logging
7
+ from types import SimpleNamespace
8
+
9
+ Connections = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
10
+
11
+
12
+ class ServerFullException(Exception):
13
+ """Exception raised when the server is full."""
14
+
15
+ pass
16
+
17
+
18
+ class ConnectionManager:
19
+ def __init__(self):
20
+ self.active_connections: Connections = {}
21
+
22
+ async def connect(
23
+ self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
24
+ ):
25
+ await websocket.accept()
26
+ user_count = self.get_user_count()
27
+ print(f"User count: {user_count}")
28
+ if max_queue_size > 0 and user_count >= max_queue_size:
29
+ print("Server is full")
30
+ await websocket.send_json({"status": "error", "message": "Server is full"})
31
+ await websocket.close()
32
+ raise ServerFullException("Server is full")
33
+ print(f"New user connected: {user_id}")
34
+ self.active_connections[user_id] = {
35
+ "websocket": websocket,
36
+ "queue": asyncio.Queue(),
37
+ }
38
+ await websocket.send_json(
39
+ {"status": "connected", "message": "Connected"},
40
+ )
41
+ await websocket.send_json({"status": "wait"})
42
+ await websocket.send_json({"status": "send_frame"})
43
+
44
+ def check_user(self, user_id: UUID) -> bool:
45
+ return user_id in self.active_connections
46
+
47
+ async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
48
+ user_session = self.active_connections.get(user_id)
49
+ if user_session:
50
+ queue = user_session["queue"]
51
+ while not queue.empty():
52
+ try:
53
+ queue.get_nowait()
54
+ except asyncio.QueueEmpty:
55
+ continue
56
+ await queue.put(new_data)
57
+
58
+ async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
59
+ user_session = self.active_connections.get(user_id)
60
+ if user_session:
61
+ queue = user_session["queue"]
62
+ try:
63
+ return await queue.get()
64
+ except asyncio.QueueEmpty:
65
+ return None
66
+
67
+ def delete_user(self, user_id: UUID):
68
+ user_session = self.active_connections.pop(user_id, None)
69
+ if user_session:
70
+ queue = user_session["queue"]
71
+ while not queue.empty():
72
+ try:
73
+ queue.get_nowait()
74
+ except asyncio.QueueEmpty:
75
+ continue
76
+
77
+ def get_user_count(self) -> int:
78
+ return len(self.active_connections)
79
+
80
+ def get_websocket(self, user_id: UUID) -> WebSocket:
81
+ user_session = self.active_connections.get(user_id)
82
+ if user_session:
83
+ websocket = user_session["websocket"]
84
+ if websocket.client_state == WebSocketState.CONNECTED:
85
+ return user_session["websocket"]
86
+ return None
87
+
88
+ async def disconnect(self, user_id: UUID):
89
+ websocket = self.get_websocket(user_id)
90
+ if websocket:
91
+ await websocket.close()
92
+ self.delete_user(user_id)
93
+
94
+ async def send_json(self, user_id: UUID, data: Dict):
95
+ try:
96
+ websocket = self.get_websocket(user_id)
97
+ if websocket:
98
+ await websocket.send_json(data)
99
+ except Exception as e:
100
+ logging.error(f"Error: Send json: {e}")
101
+
102
+ async def receive_json(self, user_id: UUID) -> Dict:
103
+ try:
104
+ websocket = self.get_websocket(user_id)
105
+ if websocket:
106
+ return await websocket.receive_json()
107
+ except Exception as e:
108
+ logging.error(f"Error: Receive json: {e}")
109
+
110
+ async def receive_bytes(self, user_id: UUID) -> bytes:
111
+ try:
112
+ websocket = self.get_websocket(user_id)
113
+ if websocket:
114
+ return await websocket.receive_bytes()
115
+ except Exception as e:
116
+ logging.error(f"Error: Receive bytes: {e}")
frontend/src/lib/components/Warning.svelte ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ export let message: string = '';
3
+
4
+ let timeout = 0;
5
+ $: if (message !== '') {
6
+ console.log('message', message);
7
+ clearTimeout(timeout);
8
+ timeout = setTimeout(() => {
9
+ message = '';
10
+ }, 5000);
11
+ }
12
+ </script>
13
+
14
+ {#if message}
15
+ <div class="fixed right-0 top-0 m-4 cursor-pointer" on:click={() => (message = '')}>
16
+ <div class="rounded bg-red-800 p-4 text-white">
17
+ {message}
18
+ </div>
19
+ <div class="bar transition-all duration-500" style="width: 0;"></div>
20
+ </div>
21
+ {/if}
22
+
23
+ <style lang="postcss" scoped>
24
+ .button {
25
+ @apply rounded bg-gray-700 font-normal text-white hover:bg-gray-800 disabled:cursor-not-allowed disabled:bg-gray-300 dark:disabled:bg-gray-700 dark:disabled:text-black;
26
+ }
27
+ </style>
frontend/src/lib/lcmLive.ts CHANGED
@@ -6,6 +6,7 @@ export enum LCMLiveStatus {
6
  DISCONNECTED = "disconnected",
7
  WAIT = "wait",
8
  SEND_FRAME = "send_frame",
 
9
  }
10
 
11
  const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
@@ -19,8 +20,9 @@ export const lcmLiveActions = {
19
  return new Promise((resolve, reject) => {
20
 
21
  try {
 
22
  const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
23
- }:${window.location.host}/api/ws`;
24
 
25
  websocket = new WebSocket(websocketURL);
26
  websocket.onopen = () => {
@@ -37,10 +39,9 @@ export const lcmLiveActions = {
37
  const data = JSON.parse(event.data);
38
  switch (data.status) {
39
  case "connected":
40
- const userId = data.userId;
41
  lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
42
  streamId.set(userId);
43
- resolve(userId);
44
  break;
45
  case "send_frame":
46
  lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
@@ -55,14 +56,16 @@ export const lcmLiveActions = {
55
  break;
56
  case "timeout":
57
  console.log("timeout");
58
- lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
59
  streamId.set(null);
60
- reject("timeout");
 
61
  case "error":
62
  console.log(data.message);
63
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
64
  streamId.set(null);
65
- reject(data.message);
 
66
  }
67
  };
68
 
@@ -86,12 +89,11 @@ export const lcmLiveActions = {
86
  }
87
  },
88
  async stop() {
89
-
90
  if (websocket) {
91
  websocket.close();
92
  }
93
  websocket = null;
94
- lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
95
  streamId.set(null);
96
  },
97
  };
 
6
  DISCONNECTED = "disconnected",
7
  WAIT = "wait",
8
  SEND_FRAME = "send_frame",
9
+ TIMEOUT = "timeout",
10
  }
11
 
12
  const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
 
20
  return new Promise((resolve, reject) => {
21
 
22
  try {
23
+ const userId = crypto.randomUUID();
24
  const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
25
+ }:${window.location.host}/api/ws/${userId}`;
26
 
27
  websocket = new WebSocket(websocketURL);
28
  websocket.onopen = () => {
 
39
  const data = JSON.parse(event.data);
40
  switch (data.status) {
41
  case "connected":
 
42
  lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
43
  streamId.set(userId);
44
+ resolve({ status: "connected", userId });
45
  break;
46
  case "send_frame":
47
  lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
 
56
  break;
57
  case "timeout":
58
  console.log("timeout");
59
+ lcmLiveStatus.set(LCMLiveStatus.TIMEOUT);
60
  streamId.set(null);
61
+ reject(new Error("timeout"));
62
+ break;
63
  case "error":
64
  console.log(data.message);
65
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
66
  streamId.set(null);
67
+ reject(new Error(data.message));
68
+ break;
69
  }
70
  };
71
 
 
89
  }
90
  },
91
  async stop() {
92
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
93
  if (websocket) {
94
  websocket.close();
95
  }
96
  websocket = null;
 
97
  streamId.set(null);
98
  },
99
  };
frontend/src/routes/+page.svelte CHANGED
@@ -7,6 +7,7 @@
7
  import Button from '$lib/components/Button.svelte';
8
  import PipelineOptions from '$lib/components/PipelineOptions.svelte';
9
  import Spinner from '$lib/icons/spinner.svelte';
 
10
  import { lcmLiveStatus, lcmLiveActions, LCMLiveStatus } from '$lib/lcmLive';
11
  import { mediaStreamActions, onFrameChangeStore } from '$lib/mediaStream';
12
  import { getPipelineValues, deboucedPipelineValues } from '$lib/store';
@@ -18,7 +19,7 @@
18
  let maxQueueSize: number = 0;
19
  let currentQueueSize: number = 0;
20
  let queueCheckerRunning: boolean = false;
21
-
22
  onMount(() => {
23
  getSettings();
24
  });
@@ -57,23 +58,31 @@
57
  }
58
 
59
  $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
60
-
 
 
61
  let disabled = false;
62
  async function toggleLcmLive() {
63
- if (!isLCMRunning) {
64
- if (isImageMode) {
65
- await mediaStreamActions.enumerateDevices();
66
- await mediaStreamActions.start();
 
 
 
 
 
 
 
 
 
 
 
 
67
  }
68
- disabled = true;
69
- await lcmLiveActions.start(getSreamdata);
70
  disabled = false;
71
- toggleQueueChecker(false);
72
- } else {
73
- if (isImageMode) {
74
- mediaStreamActions.stop();
75
- }
76
- lcmLiveActions.stop();
77
  toggleQueueChecker(true);
78
  }
79
  }
@@ -86,6 +95,7 @@
86
  </svelte:head>
87
 
88
  <main class="container mx-auto flex max-w-5xl flex-col gap-3 px-4 py-4">
 
89
  <article class="text-center">
90
  {#if pageContent}
91
  {@html pageContent}
 
7
  import Button from '$lib/components/Button.svelte';
8
  import PipelineOptions from '$lib/components/PipelineOptions.svelte';
9
  import Spinner from '$lib/icons/spinner.svelte';
10
+ import Warning from '$lib/components/Warning.svelte';
11
  import { lcmLiveStatus, lcmLiveActions, LCMLiveStatus } from '$lib/lcmLive';
12
  import { mediaStreamActions, onFrameChangeStore } from '$lib/mediaStream';
13
  import { getPipelineValues, deboucedPipelineValues } from '$lib/store';
 
19
  let maxQueueSize: number = 0;
20
  let currentQueueSize: number = 0;
21
  let queueCheckerRunning: boolean = false;
22
+ let warningMessage: string = '';
23
  onMount(() => {
24
  getSettings();
25
  });
 
58
  }
59
 
60
  $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
61
+ $: if ($lcmLiveStatus === LCMLiveStatus.TIMEOUT) {
62
+ warningMessage = 'Session timed out. Please try again.';
63
+ }
64
  let disabled = false;
65
  async function toggleLcmLive() {
66
+ try {
67
+ if (!isLCMRunning) {
68
+ if (isImageMode) {
69
+ await mediaStreamActions.enumerateDevices();
70
+ await mediaStreamActions.start();
71
+ }
72
+ disabled = true;
73
+ await lcmLiveActions.start(getSreamdata);
74
+ disabled = false;
75
+ toggleQueueChecker(false);
76
+ } else {
77
+ if (isImageMode) {
78
+ mediaStreamActions.stop();
79
+ }
80
+ lcmLiveActions.stop();
81
+ toggleQueueChecker(true);
82
  }
83
+ } catch (e) {
84
+ warningMessage = e instanceof Error ? e.message : '';
85
  disabled = false;
 
 
 
 
 
 
86
  toggleQueueChecker(true);
87
  }
88
  }
 
95
  </svelte:head>
96
 
97
  <main class="container mx-auto flex max-w-5xl flex-col gap-3 px-4 py-4">
98
+ <Warning bind:message={warningMessage}></Warning>
99
  <article class="text-center">
100
  {#if pageContent}
101
  {@html pageContent}
main.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
2
+ from fastapi.responses import StreamingResponse, JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi import Request
6
+ import markdown2
7
+
8
+ import logging
9
+ from config import config, Args
10
+ from connection_manager import ConnectionManager
11
+ import uuid
12
+ import time
13
+ from types import SimpleNamespace
14
+ from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class
15
+ from device import device, torch_dtype
16
+ import asyncio
17
+ import os
18
+ import time
19
+ import torch
20
+
21
+
22
+ THROTTLE = 1.0 / 120
23
+
24
+
25
+ class App:
26
+ def __init__(self, config: Args, pipeline):
27
+ self.args = config
28
+ self.pipeline = pipeline
29
+ self.app = FastAPI()
30
+ self.conn_manager = ConnectionManager()
31
+ self.init_app()
32
+
33
+ def init_app(self):
34
+ self.app.add_middleware(
35
+ CORSMiddleware,
36
+ allow_origins=["*"],
37
+ allow_credentials=True,
38
+ allow_methods=["*"],
39
+ allow_headers=["*"],
40
+ )
41
+
42
+ @self.app.websocket("/api/ws/{user_id}")
43
+ async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
44
+ try:
45
+ await self.conn_manager.connect(
46
+ user_id, websocket, self.args.max_queue_size
47
+ )
48
+ await handle_websocket_data(user_id)
49
+ except ServerFullException as e:
50
+ logging.error(f"Server Full: {e}")
51
+ finally:
52
+ await self.conn_manager.disconnect(user_id)
53
+ logging.info(f"User disconnected: {user_id}")
54
+
55
+ async def handle_websocket_data(user_id: uuid.UUID):
56
+ if not self.conn_manager.check_user(user_id):
57
+ return HTTPException(status_code=404, detail="User not found")
58
+ last_time = time.time()
59
+ try:
60
+ while True:
61
+ if (
62
+ self.args.timeout > 0
63
+ and time.time() - last_time > self.args.timeout
64
+ ):
65
+ await self.conn_manager.send_json(
66
+ user_id,
67
+ {
68
+ "status": "timeout",
69
+ "message": "Your session has ended",
70
+ },
71
+ )
72
+ await self.conn_manager.disconnect(user_id)
73
+ return
74
+ data = await self.conn_manager.receive_json(user_id)
75
+ if data["status"] != "next_frame":
76
+ asyncio.sleep(THROTTLE)
77
+ continue
78
+
79
+ params = await self.conn_manager.receive_json(user_id)
80
+ params = pipeline.InputParams(**params)
81
+ info = pipeline.Info()
82
+ params = SimpleNamespace(**params.dict())
83
+ if info.input_mode == "image":
84
+ image_data = await self.conn_manager.receive_bytes(user_id)
85
+ if len(image_data) == 0:
86
+ await self.conn_manager.send_json(
87
+ user_id, {"status": "send_frame"}
88
+ )
89
+ await asyncio.sleep(THROTTLE)
90
+ continue
91
+ params.image = bytes_to_pil(image_data)
92
+ await self.conn_manager.update_data(user_id, params)
93
+ await self.conn_manager.send_json(user_id, {"status": "wait"})
94
+
95
+ except Exception as e:
96
+ logging.error(f"Websocket Error: {e}, {user_id} ")
97
+ await self.conn_manager.disconnect(user_id)
98
+
99
+ @self.app.get("/api/queue")
100
+ async def get_queue_size():
101
+ queue_size = self.conn_manager.get_user_count()
102
+ return JSONResponse({"queue_size": queue_size})
103
+
104
+ @self.app.get("/api/stream/{user_id}")
105
+ async def stream(user_id: uuid.UUID, request: Request):
106
+ try:
107
+
108
+ async def generate():
109
+ last_params = SimpleNamespace()
110
+ while True:
111
+ last_time = time.time()
112
+ params = await self.conn_manager.get_latest_data(user_id)
113
+ if not vars(params) or params.__dict__ == last_params.__dict__:
114
+ await self.conn_manager.send_json(
115
+ user_id, {"status": "send_frame"}
116
+ )
117
+ continue
118
+
119
+ last_params = params
120
+ image = pipeline.predict(params)
121
+ if image is None:
122
+ await self.conn_manager.send_json(
123
+ user_id, {"status": "send_frame"}
124
+ )
125
+ continue
126
+ frame = pil_to_frame(image)
127
+ yield frame
128
+ # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
129
+ if not is_firefox(request.headers["user-agent"]):
130
+ yield frame
131
+ await self.conn_manager.send_json(
132
+ user_id, {"status": "send_frame"}
133
+ )
134
+ if self.args.debug:
135
+ print(f"Time taken: {time.time() - last_time}")
136
+
137
+ return StreamingResponse(
138
+ generate(),
139
+ media_type="multipart/x-mixed-replace;boundary=frame",
140
+ headers={"Cache-Control": "no-cache"},
141
+ )
142
+ except Exception as e:
143
+ logging.error(f"Streaming Error: {e}, {user_id} ")
144
+ return HTTPException(status_code=404, detail="User not found")
145
+
146
+ # route to setup frontend
147
+ @self.app.get("/api/settings")
148
+ async def settings():
149
+ info_schema = pipeline.Info.schema()
150
+ info = pipeline.Info()
151
+ if info.page_content:
152
+ page_content = markdown2.markdown(info.page_content)
153
+
154
+ input_params = pipeline.InputParams.schema()
155
+ return JSONResponse(
156
+ {
157
+ "info": info_schema,
158
+ "input_params": input_params,
159
+ "max_queue_size": self.args.max_queue_size,
160
+ "page_content": page_content if info.page_content else "",
161
+ }
162
+ )
163
+
164
+ if not os.path.exists("public"):
165
+ os.makedirs("public")
166
+
167
+ self.app.mount("/", StaticFiles(directory="public", html=True), name="public")
168
+
169
+
170
+ pipeline_class = get_pipeline_class(config.pipeline)
171
+ pipeline = pipeline_class(config, device, torch_dtype)
172
+ app = App(config, pipeline).app
173
+
174
+ if __name__ == "__main__":
175
+ import uvicorn
176
+
177
+ uvicorn.run(
178
+ "main:app",
179
+ host=config.host,
180
+ port=config.port,
181
+ reload=config.reload,
182
+ ssl_certfile=config.ssl_certfile,
183
+ ssl_keyfile=config.ssl_keyfile,
184
+ )
run.py DELETED
@@ -1,12 +0,0 @@
1
- if __name__ == "__main__":
2
- import uvicorn
3
- from config import args
4
-
5
- uvicorn.run(
6
- "app:app",
7
- host=args.host,
8
- port=args.port,
9
- reload=args.reload,
10
- ssl_certfile=args.ssl_certfile,
11
- ssl_keyfile=args.ssl_keyfile,
12
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
user_queue.py DELETED
@@ -1,63 +0,0 @@
1
- from typing import Dict
2
- from uuid import UUID
3
- import asyncio
4
- from fastapi import WebSocket
5
- from types import SimpleNamespace
6
- from typing import Dict
7
- from typing import Union
8
-
9
- UserDataContent = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
10
-
11
-
12
- class UserData:
13
- def __init__(self):
14
- self.data_content: Dict[UUID, UserDataContent] = {}
15
-
16
- async def create_user(self, user_id: UUID, websocket: WebSocket):
17
- self.data_content[user_id] = {
18
- "websocket": websocket,
19
- "queue": asyncio.Queue(),
20
- }
21
- await asyncio.sleep(1)
22
-
23
- def check_user(self, user_id: UUID) -> bool:
24
- return user_id in self.data_content
25
-
26
- async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
27
- user_session = self.data_content[user_id]
28
- queue = user_session["queue"]
29
- while not queue.empty():
30
- try:
31
- queue.get_nowait()
32
- except asyncio.QueueEmpty:
33
- continue
34
- await queue.put(new_data)
35
-
36
- async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
37
- user_session = self.data_content[user_id]
38
- queue = user_session["queue"]
39
-
40
- try:
41
- return await queue.get()
42
- except asyncio.QueueEmpty:
43
- return None
44
-
45
- def delete_user(self, user_id: UUID):
46
- user_session = self.data_content[user_id]
47
- queue = user_session["queue"]
48
- while not queue.empty():
49
- try:
50
- queue.get_nowait()
51
- except asyncio.QueueEmpty:
52
- continue
53
- if user_id in self.data_content:
54
- del self.data_content[user_id]
55
-
56
- def get_user_count(self) -> int:
57
- return len(self.data_content)
58
-
59
- def get_websocket(self, user_id: UUID) -> WebSocket:
60
- return self.data_content[user_id]["websocket"]
61
-
62
-
63
- user_data = UserData()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
util.py CHANGED
@@ -1,7 +1,5 @@
1
  from importlib import import_module
2
  from types import ModuleType
3
- from typing import Dict, Any
4
- from pydantic import BaseModel as PydanticBaseModel, Field
5
  from PIL import Image
6
  import io
7
 
 
1
  from importlib import import_module
2
  from types import ModuleType
 
 
3
  from PIL import Image
4
  import io
5