radames commited on
Commit
9a8789a
·
1 Parent(s): 2ab3299

refactor: fix websocket errors

Browse files
frontend/src/lib/components/ImagePlayer.svelte CHANGED
@@ -7,8 +7,8 @@
7
  import Expand from '$lib/icons/expand.svelte';
8
  import { snapImage, expandWindow } from '$lib/utils';
9
 
10
- $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
11
- $: console.log('isLCMRunning', isLCMRunning);
12
  let imageEl: HTMLImageElement;
13
  let expandedWindow: Window;
14
  let isExpanded = false;
@@ -40,12 +40,26 @@
40
  class="relative mx-auto aspect-square max-w-lg self-center overflow-hidden rounded-lg border border-slate-300"
41
  >
42
  <!-- svelte-ignore a11y-missing-attribute -->
43
- {#if isLCMRunning}
 
 
 
 
 
 
44
  {#if !isExpanded}
 
45
  <img
46
  bind:this={imageEl}
47
  class="aspect-square w-full rounded-lg"
48
  src={'/api/stream/' + $streamId}
 
 
 
 
 
 
 
49
  />
50
  {/if}
51
  <div class="absolute bottom-1 right-1">
@@ -65,6 +79,13 @@
65
  <Floppy classList={''} />
66
  </Button>
67
  </div>
 
 
 
 
 
 
 
68
  {:else}
69
  <img
70
  class="aspect-square w-full rounded-lg"
 
7
  import Expand from '$lib/icons/expand.svelte';
8
  import { snapImage, expandWindow } from '$lib/utils';
9
 
10
+ $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED &&
11
+ $lcmLiveStatus !== LCMLiveStatus.ERROR;
12
  let imageEl: HTMLImageElement;
13
  let expandedWindow: Window;
14
  let isExpanded = false;
 
40
  class="relative mx-auto aspect-square max-w-lg self-center overflow-hidden rounded-lg border border-slate-300"
41
  >
42
  <!-- svelte-ignore a11y-missing-attribute -->
43
+ {#if $lcmLiveStatus === LCMLiveStatus.CONNECTING}
44
+ <!-- Show connecting spinner -->
45
+ <div class="flex items-center justify-center h-full w-full">
46
+ <div class="animate-spin rounded-full h-16 w-16 border-b-2 border-white"></div>
47
+ <p class="text-white ml-2">Connecting...</p>
48
+ </div>
49
+ {:else if isLCMRunning}
50
  {#if !isExpanded}
51
+ <!-- Handle image error by adding onerror event -->
52
  <img
53
  bind:this={imageEl}
54
  class="aspect-square w-full rounded-lg"
55
  src={'/api/stream/' + $streamId}
56
+ on:error={(e) => {
57
+ console.error('Image stream error:', e);
58
+ // If stream fails to load, set status to error
59
+ if ($lcmLiveStatus !== LCMLiveStatus.ERROR) {
60
+ lcmLiveStatus.set(LCMLiveStatus.ERROR);
61
+ }
62
+ }}
63
  />
64
  {/if}
65
  <div class="absolute bottom-1 right-1">
 
79
  <Floppy classList={''} />
80
  </Button>
81
  </div>
82
+ {:else if $lcmLiveStatus === LCMLiveStatus.ERROR}
83
+ <!-- Show error state with red border -->
84
+ <div class="flex items-center justify-center h-full w-full border-2 border-red-500 rounded-lg bg-gray-900">
85
+ <p class="text-center text-white p-4">
86
+ Connection error
87
+ </p>
88
+ </div>
89
  {:else}
90
  <img
91
  class="aspect-square w-full rounded-lg"
frontend/src/lib/lcmLive.ts CHANGED
@@ -1,11 +1,13 @@
1
- import { writable } from 'svelte/store';
2
 
3
  export enum LCMLiveStatus {
4
  CONNECTED = 'connected',
5
  DISCONNECTED = 'disconnected',
 
6
  WAIT = 'wait',
7
  SEND_FRAME = 'send_frame',
8
- TIMEOUT = 'timeout'
 
9
  }
10
 
11
  const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
@@ -13,85 +15,194 @@ const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
13
  export const lcmLiveStatus = writable<LCMLiveStatus>(initStatus);
14
  export const streamId = writable<string | null>(null);
15
 
 
16
  let websocket: WebSocket | null = null;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  export const lcmLiveActions = {
18
  async start(getSreamdata: () => any[]) {
19
  return new Promise((resolve, reject) => {
20
  try {
 
 
 
21
  const userId = crypto.randomUUID();
22
  const websocketURL = `${
23
  window.location.protocol === 'https:' ? 'wss' : 'ws'
24
  }:${window.location.host}/api/ws/${userId}`;
25
 
 
 
 
 
 
26
  websocket = new WebSocket(websocketURL);
 
 
 
 
 
 
 
 
 
 
 
 
27
  websocket.onopen = () => {
 
28
  console.log('Connected to websocket');
29
  };
30
- websocket.onclose = () => {
31
- lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
32
- console.log('Disconnected from websocket');
 
 
 
 
 
 
 
 
 
 
 
33
  };
 
34
  websocket.onerror = (err) => {
35
- console.error(err);
 
 
 
 
36
  };
 
37
  websocket.onmessage = (event) => {
38
- const data = JSON.parse(event.data);
39
- switch (data.status) {
40
- case 'connected':
41
- lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
42
- streamId.set(userId);
43
- resolve({ status: 'connected', userId });
44
- break;
45
- case 'send_frame':
46
- lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
47
- const streamData = getSreamdata();
48
- websocket?.send(JSON.stringify({ status: 'next_frame' }));
49
- for (const d of streamData) {
50
- this.send(d);
51
- }
52
- break;
53
- case 'wait':
54
- lcmLiveStatus.set(LCMLiveStatus.WAIT);
55
- break;
56
- case 'timeout':
57
- console.log('timeout');
58
- lcmLiveStatus.set(LCMLiveStatus.TIMEOUT);
59
- streamId.set(null);
60
- reject(new Error('timeout'));
61
- break;
62
- case 'error':
63
- console.log(data.message);
64
- lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
65
- streamId.set(null);
66
- reject(new Error(data.message));
67
- break;
 
 
 
 
 
 
 
 
 
 
 
68
  }
69
  };
70
  } catch (err) {
71
- console.error(err);
72
- lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
73
  streamId.set(null);
74
  reject(err);
75
  }
76
  });
77
  },
78
  send(data: Blob | { [key: string]: any }) {
79
- if (websocket && websocket.readyState === WebSocket.OPEN) {
80
- if (data instanceof Blob) {
81
- websocket.send(data);
 
 
 
 
82
  } else {
83
- websocket.send(JSON.stringify(data));
 
 
 
 
 
 
 
 
 
84
  }
85
- } else {
86
- console.log('WebSocket not connected');
 
 
87
  }
88
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  async stop() {
90
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
91
- if (websocket) {
92
- websocket.close();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  }
94
- websocket = null;
95
- streamId.set(null);
96
  }
97
  };
 
1
+ import { get, writable } from 'svelte/store';
2
 
3
  export enum LCMLiveStatus {
4
  CONNECTED = 'connected',
5
  DISCONNECTED = 'disconnected',
6
+ CONNECTING = 'connecting',
7
  WAIT = 'wait',
8
  SEND_FRAME = 'send_frame',
9
+ TIMEOUT = 'timeout',
10
+ ERROR = 'error'
11
  }
12
 
13
  const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
 
15
  export const lcmLiveStatus = writable<LCMLiveStatus>(initStatus);
16
  export const streamId = writable<string | null>(null);
17
 
18
+ // WebSocket connection
19
  let websocket: WebSocket | null = null;
20
+ // Flag to track intentional connection closure
21
+ let intentionalClosure = false;
22
+
23
+ // Register browser unload event listener to properly close WebSockets
24
+ if (typeof window !== 'undefined') {
25
+ window.addEventListener('beforeunload', () => {
26
+ // Mark any closure during page unload as intentional
27
+ intentionalClosure = true;
28
+ // Close the WebSocket properly if it exists
29
+ if (websocket && websocket.readyState === WebSocket.OPEN) {
30
+ websocket.close(1000, 'Page unload');
31
+ }
32
+ });
33
+ }
34
  export const lcmLiveActions = {
35
  async start(getSreamdata: () => any[]) {
36
  return new Promise((resolve, reject) => {
37
  try {
38
+ // Set connecting status immediately
39
+ lcmLiveStatus.set(LCMLiveStatus.CONNECTING);
40
+
41
  const userId = crypto.randomUUID();
42
  const websocketURL = `${
43
  window.location.protocol === 'https:' ? 'wss' : 'ws'
44
  }:${window.location.host}/api/ws/${userId}`;
45
 
46
+ // Close any existing connection first
47
+ if (websocket && websocket.readyState !== WebSocket.CLOSED) {
48
+ websocket.close();
49
+ }
50
+
51
  websocket = new WebSocket(websocketURL);
52
+
53
+ // Set a connection timeout
54
+ const connectionTimeout = setTimeout(() => {
55
+ if (websocket && websocket.readyState !== WebSocket.OPEN) {
56
+ console.error('WebSocket connection timeout');
57
+ lcmLiveStatus.set(LCMLiveStatus.ERROR);
58
+ streamId.set(null);
59
+ reject(new Error('Connection timeout. Please try again.'));
60
+ websocket.close();
61
+ }
62
+ }, 10000); // 10 second timeout
63
+
64
  websocket.onopen = () => {
65
+ clearTimeout(connectionTimeout);
66
  console.log('Connected to websocket');
67
  };
68
+
69
+ websocket.onclose = (event) => {
70
+ clearTimeout(connectionTimeout);
71
+ console.log(`Disconnected from websocket: ${event.code} ${event.reason}`);
72
+
73
+ // Only change status if we're not in ERROR state (which would mean we already handled the error)
74
+ if (get(lcmLiveStatus) !== LCMLiveStatus.ERROR) {
75
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
76
+ }
77
+
78
+ // If connection was never established (close without open)
79
+ if (event.code === 1006 && streamId.get() === null) {
80
+ reject(new Error('Cannot connect to server. Please try again later.'));
81
+ }
82
  };
83
+
84
  websocket.onerror = (err) => {
85
+ clearTimeout(connectionTimeout);
86
+ console.error('WebSocket error:', err);
87
+ lcmLiveStatus.set(LCMLiveStatus.ERROR);
88
+ streamId.set(null);
89
+ reject(new Error('Connection error. Please try again.'));
90
  };
91
+
92
  websocket.onmessage = (event) => {
93
+ try {
94
+ const data = JSON.parse(event.data);
95
+ switch (data.status) {
96
+ case 'connected':
97
+ lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
98
+ streamId.set(userId);
99
+ resolve({ status: 'connected', userId });
100
+ break;
101
+ case 'send_frame':
102
+ lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
103
+ try {
104
+ const streamData = getSreamdata();
105
+ // Send as an object, not a string, to use the proper handling in the send method
106
+ this.send({ status: 'next_frame' });
107
+ for (const d of streamData) {
108
+ this.send(d);
109
+ }
110
+ } catch (error) {
111
+ console.error('Error sending frame data:', error);
112
+ }
113
+ break;
114
+ case 'wait':
115
+ lcmLiveStatus.set(LCMLiveStatus.WAIT);
116
+ break;
117
+ case 'timeout':
118
+ console.log('Session timeout');
119
+ lcmLiveStatus.set(LCMLiveStatus.TIMEOUT);
120
+ streamId.set(null);
121
+ reject(new Error('Session timeout. Please restart.'));
122
+ break;
123
+ case 'error':
124
+ console.error('Server error:', data.message);
125
+ lcmLiveStatus.set(LCMLiveStatus.ERROR);
126
+ streamId.set(null);
127
+ reject(new Error(data.message || 'Server error occurred'));
128
+ break;
129
+ default:
130
+ console.log('Unknown message status:', data.status);
131
+ }
132
+ } catch (error) {
133
+ console.error('Error handling websocket message:', error);
134
  }
135
  };
136
  } catch (err) {
137
+ console.error('Error initializing websocket:', err);
138
+ lcmLiveStatus.set(LCMLiveStatus.ERROR);
139
  streamId.set(null);
140
  reject(err);
141
  }
142
  });
143
  },
144
  send(data: Blob | { [key: string]: any }) {
145
+ try {
146
+ if (websocket && websocket.readyState === WebSocket.OPEN) {
147
+ if (data instanceof Blob) {
148
+ websocket.send(data);
149
+ } else {
150
+ websocket.send(JSON.stringify(data));
151
+ }
152
  } else {
153
+ const readyStateText = websocket
154
+ ? ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED'][websocket.readyState]
155
+ : 'null';
156
+ console.warn(`WebSocket not ready for sending: ${readyStateText}`);
157
+
158
+ // If WebSocket is closed unexpectedly, set status to disconnected
159
+ if (!websocket || websocket.readyState === WebSocket.CLOSED) {
160
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
161
+ streamId.set(null);
162
+ }
163
  }
164
+ } catch (error) {
165
+ console.error('Error sending data through WebSocket:', error);
166
+ // Handle WebSocket error by forcing disconnection
167
+ this.stop();
168
  }
169
  },
170
+
171
+ async reconnect(getSreamdata: () => any[]) {
172
+ try {
173
+ await this.stop();
174
+ // Small delay to ensure clean disconnection before reconnecting
175
+ await new Promise((resolve) => setTimeout(resolve, 500));
176
+ return await this.start(getSreamdata);
177
+ } catch (error) {
178
+ console.error('Reconnection failed:', error);
179
+ throw error;
180
+ }
181
+ },
182
+
183
  async stop() {
184
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
185
+ try {
186
+ if (websocket) {
187
+ // Only attempt to close if not already closed
188
+ if (websocket.readyState !== WebSocket.CLOSED) {
189
+ // Set up onclose handler to clean up only
190
+ websocket.onclose = () => {
191
+ console.log('WebSocket closed cleanly during stop()');
192
+ };
193
+
194
+ // Set up onerror to be silent during intentional closure
195
+ websocket.onerror = () => {};
196
+
197
+ websocket.close(1000, 'Client initiated disconnect');
198
+ }
199
+ }
200
+ } catch (error) {
201
+ console.error('Error during WebSocket closure:', error);
202
+ } finally {
203
+ // Always clean up references
204
+ websocket = null;
205
+ streamId.set(null);
206
  }
 
 
207
  }
208
  };
frontend/src/routes/+page.svelte CHANGED
@@ -57,35 +57,86 @@
57
  }
58
  }
59
 
60
- $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
61
- $: if ($lcmLiveStatus === LCMLiveStatus.TIMEOUT) {
62
- warningMessage = 'Session timed out. Please try again.';
 
 
 
 
 
 
 
 
63
  }
64
  let disabled = false;
65
  async function toggleLcmLive() {
66
  try {
67
  if (!isLCMRunning) {
68
- if (isImageMode) {
69
- await mediaStreamActions.enumerateDevices();
70
- await mediaStreamActions.start();
71
  }
 
 
 
72
  disabled = true;
73
- await lcmLiveActions.start(getSreamdata);
74
- disabled = false;
75
- toggleQueueChecker(false);
 
 
 
 
 
 
 
 
 
 
76
  } else {
77
- if (isImageMode) {
78
- mediaStreamActions.stop();
 
 
 
 
 
 
 
 
 
79
  }
80
- lcmLiveActions.stop();
81
- toggleQueueChecker(true);
82
  }
83
  } catch (e) {
84
- warningMessage = e instanceof Error ? e.message : '';
 
85
  disabled = false;
86
  toggleQueueChecker(true);
87
  }
88
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  </script>
90
 
91
  <svelte:head>
@@ -111,6 +162,17 @@
111
  > and run it on your own GPU.
112
  </p>
113
  {/if}
 
 
 
 
 
 
 
 
 
 
 
114
  </article>
115
  {#if pipelineParams}
116
  <article class="my-3 grid grid-cols-1 gap-3 sm:grid-cols-4">
@@ -127,7 +189,9 @@
127
  </div>
128
  <div class="sm:col-span-4 sm:row-start-2">
129
  <Button on:click={toggleLcmLive} {disabled} classList={'text-lg my-1 p-2'}>
130
- {#if isLCMRunning}
 
 
131
  Stop
132
  {:else}
133
  Start
 
57
  }
58
  }
59
 
60
+ $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED &&
61
+ $lcmLiveStatus !== LCMLiveStatus.ERROR;
62
+ $: isConnecting = $lcmLiveStatus === LCMLiveStatus.CONNECTING;
63
+
64
+ $: {
65
+ // Set warning messages based on lcmLiveStatus
66
+ if ($lcmLiveStatus === LCMLiveStatus.TIMEOUT) {
67
+ warningMessage = 'Session timed out. Please try again.';
68
+ } else if ($lcmLiveStatus === LCMLiveStatus.ERROR) {
69
+ warningMessage = 'Connection error occurred. Please try again.';
70
+ }
71
  }
72
  let disabled = false;
73
  async function toggleLcmLive() {
74
  try {
75
  if (!isLCMRunning) {
76
+ if (isConnecting) {
77
+ return; // Don't allow multiple connection attempts
 
78
  }
79
+
80
+ // Clear any previous warning messages
81
+ warningMessage = '';
82
  disabled = true;
83
+
84
+ try {
85
+ if (isImageMode) {
86
+ await mediaStreamActions.enumerateDevices();
87
+ await mediaStreamActions.start();
88
+ }
89
+
90
+ await lcmLiveActions.start(getSreamdata);
91
+ toggleQueueChecker(false);
92
+ } finally {
93
+ // Always re-enable the button even if there was an error
94
+ disabled = false;
95
+ }
96
  } else {
97
+ // Handle stopping - disable button during this process too
98
+ disabled = true;
99
+
100
+ try {
101
+ if (isImageMode) {
102
+ mediaStreamActions.stop();
103
+ }
104
+ await lcmLiveActions.stop();
105
+ toggleQueueChecker(true);
106
+ } finally {
107
+ disabled = false;
108
  }
 
 
109
  }
110
  } catch (e) {
111
+ console.error('Error in toggleLcmLive:', e);
112
+ warningMessage = e instanceof Error ? e.message : 'An unknown error occurred';
113
  disabled = false;
114
  toggleQueueChecker(true);
115
  }
116
  }
117
+
118
+ // Reconnect function for automatic reconnection
119
+ async function reconnect() {
120
+ try {
121
+ disabled = true;
122
+ warningMessage = 'Reconnecting...';
123
+
124
+ if (isImageMode) {
125
+ await mediaStreamActions.stop();
126
+ await mediaStreamActions.enumerateDevices();
127
+ await mediaStreamActions.start();
128
+ }
129
+
130
+ await lcmLiveActions.reconnect(getSreamdata);
131
+ warningMessage = '';
132
+ toggleQueueChecker(false);
133
+ } catch (e) {
134
+ warningMessage = e instanceof Error ? e.message : 'Reconnection failed';
135
+ toggleQueueChecker(true);
136
+ } finally {
137
+ disabled = false;
138
+ }
139
+ }
140
  </script>
141
 
142
  <svelte:head>
 
162
  > and run it on your own GPU.
163
  </p>
164
  {/if}
165
+
166
+ {#if $lcmLiveStatus === LCMLiveStatus.ERROR}
167
+ <p class="text-sm mt-2">
168
+ <button
169
+ class="text-blue-500 underline hover:no-underline"
170
+ on:click={reconnect}
171
+ disabled={disabled}>
172
+ Try reconnecting
173
+ </button>
174
+ </p>
175
+ {/if}
176
  </article>
177
  {#if pipelineParams}
178
  <article class="my-3 grid grid-cols-1 gap-3 sm:grid-cols-4">
 
189
  </div>
190
  <div class="sm:col-span-4 sm:row-start-2">
191
  <Button on:click={toggleLcmLive} {disabled} classList={'text-lg my-1 p-2'}>
192
+ {#if isConnecting}
193
+ Connecting...
194
+ {:else if isLCMRunning}
195
  Stop
196
  {:else}
197
  Start
server/config.py CHANGED
@@ -1,9 +1,10 @@
1
- from typing import NamedTuple
2
  import argparse
3
  import os
 
4
 
5
 
6
- class Args(NamedTuple):
7
  host: str
8
  port: int
9
  reload: bool
@@ -13,19 +14,30 @@ class Args(NamedTuple):
13
  torch_compile: bool
14
  taesd: bool
15
  pipeline: str
16
- ssl_certfile: str
17
- ssl_keyfile: str
18
  sfast: bool
19
  onediff: bool = False
20
  compel: bool = False
21
  debug: bool = False
22
 
23
- def pretty_print(self):
24
  print("\n")
25
- for field, value in self._asdict().items():
26
  print(f"{field}: {value}")
27
  print("\n")
28
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
31
  TIMEOUT = float(os.environ.get("TIMEOUT", 0))
@@ -113,5 +125,5 @@ parser.add_argument(
113
  )
114
  parser.set_defaults(taesd=USE_TAESD)
115
 
116
- config = Args(**vars(parser.parse_args()))
117
  config.pretty_print()
 
1
+ from pydantic import BaseModel, field_validator
2
  import argparse
3
  import os
4
+ from typing import Annotated
5
 
6
 
7
+ class Args(BaseModel):
8
  host: str
9
  port: int
10
  reload: bool
 
14
  torch_compile: bool
15
  taesd: bool
16
  pipeline: str
17
+ ssl_certfile: str | None
18
+ ssl_keyfile: str | None
19
  sfast: bool
20
  onediff: bool = False
21
  compel: bool = False
22
  debug: bool = False
23
 
24
+ def pretty_print(self) -> None:
25
  print("\n")
26
+ for field, value in self.model_dump().items():
27
  print(f"{field}: {value}")
28
  print("\n")
29
 
30
+ @field_validator("ssl_keyfile")
31
+ @classmethod
32
+ def validate_ssl_keyfile(cls, v: str | None, info) -> str | None:
33
+ """Validate that if ssl_certfile is provided, ssl_keyfile is also provided."""
34
+ ssl_certfile = info.data.get("ssl_certfile")
35
+ if ssl_certfile and not v:
36
+ raise ValueError(
37
+ "If ssl_certfile is provided, ssl_keyfile must also be provided"
38
+ )
39
+ return v
40
+
41
 
42
  MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
43
  TIMEOUT = float(os.environ.get("TIMEOUT", 0))
 
125
  )
126
  parser.set_defaults(taesd=USE_TAESD)
127
 
128
+ config = Args.model_validate(vars(parser.parse_args()))
129
  config.pretty_print()
server/connection_manager.py CHANGED
@@ -1,6 +1,7 @@
1
  from uuid import UUID
2
  import asyncio
3
  from fastapi import WebSocket
 
4
  from starlette.websockets import WebSocketState
5
  import logging
6
  from typing import Any
@@ -73,44 +74,157 @@ class ConnectionManager:
73
  def get_user_count(self) -> int:
74
  return len(self.active_connections)
75
 
76
- def get_websocket(self, user_id: UUID) -> WebSocket:
77
  user_session = self.active_connections.get(user_id)
78
  if user_session:
79
  websocket = user_session["websocket"]
80
- if websocket.client_state == WebSocketState.CONNECTED:
 
 
 
81
  return user_session["websocket"]
82
  return None
83
 
84
  async def disconnect(self, user_id: UUID):
85
- websocket = self.get_websocket(user_id)
86
- if websocket:
87
- await websocket.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  self.delete_user(user_id)
89
 
90
  async def send_json(self, user_id: UUID, data: dict):
91
  try:
92
  websocket = self.get_websocket(user_id)
93
  if websocket:
94
- await websocket.send_json(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  except Exception as e:
96
  logging.error(f"Error: Send json: {e}")
 
 
 
97
 
98
  async def receive_json(self, user_id: UUID) -> dict | None:
99
  try:
100
  websocket = self.get_websocket(user_id)
101
  if websocket:
102
- return await websocket.receive_json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  return None
104
  except Exception as e:
105
  logging.error(f"Error: Receive json: {e}")
 
 
 
106
  return None
107
 
108
  async def receive_bytes(self, user_id: UUID) -> bytes | None:
109
  try:
110
  websocket = self.get_websocket(user_id)
111
  if websocket:
112
- return await websocket.receive_bytes()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  return None
114
  except Exception as e:
115
  logging.error(f"Error: Receive bytes: {e}")
 
 
 
116
  return None
 
1
  from uuid import UUID
2
  import asyncio
3
  from fastapi import WebSocket
4
+ from fastapi.websockets import WebSocketDisconnect
5
  from starlette.websockets import WebSocketState
6
  import logging
7
  from typing import Any
 
74
  def get_user_count(self) -> int:
75
  return len(self.active_connections)
76
 
77
+ def get_websocket(self, user_id: UUID) -> WebSocket | None:
78
  user_session = self.active_connections.get(user_id)
79
  if user_session:
80
  websocket = user_session["websocket"]
81
+ # Both client_state and application_state should be checked
82
+ # to ensure the websocket is fully connected and not closing
83
+ if (websocket.client_state == WebSocketState.CONNECTED and
84
+ websocket.application_state == WebSocketState.CONNECTED):
85
  return user_session["websocket"]
86
  return None
87
 
88
  async def disconnect(self, user_id: UUID):
89
+ # First check if user is in active connections
90
+ if user_id not in self.active_connections:
91
+ return
92
+
93
+ # Get the websocket directly from active_connections to avoid get_websocket validation
94
+ user_session = self.active_connections.get(user_id)
95
+ if user_session and "websocket" in user_session:
96
+ websocket = user_session["websocket"]
97
+ try:
98
+ # Only attempt close if not already closed
99
+ if (websocket.client_state != WebSocketState.DISCONNECTED and
100
+ websocket.application_state != WebSocketState.DISCONNECTED):
101
+ await websocket.close()
102
+ except Exception as e:
103
+ logging.error(f"Error closing websocket for {user_id}: {e}")
104
+
105
+ # Always delete the user to ensure cleanup
106
  self.delete_user(user_id)
107
 
108
  async def send_json(self, user_id: UUID, data: dict):
109
  try:
110
  websocket = self.get_websocket(user_id)
111
  if websocket:
112
+ try:
113
+ await websocket.send_json(data)
114
+ except RuntimeError as e:
115
+ error_msg = str(e)
116
+ if any(err in error_msg for err in [
117
+ "WebSocket is not connected",
118
+ "Cannot call \"send\" once a close message has been sent",
119
+ "Cannot call \"receive\" once a close message has been sent",
120
+ "WebSocket is disconnected"]):
121
+ # The websocket was disconnected or is closing
122
+ logging.info(f"WebSocket disconnected/closing for user {user_id}: {error_msg}")
123
+ await self.disconnect(user_id)
124
+ else:
125
+ logging.error(f"Runtime error in send_json: {e}")
126
+ except WebSocketDisconnect as disconnect_error:
127
+ # Handle websocket disconnection event
128
+ code = disconnect_error.code
129
+ if code == 1006: # ABNORMAL_CLOSURE
130
+ logging.info(f"WebSocket abnormally closed for user {user_id} during send: Connection was closed without a proper close handshake")
131
+ else:
132
+ logging.info(f"WebSocket disconnected for user {user_id} with code {code} during send: {disconnect_error.reason}")
133
+
134
+ # Always disconnect the user
135
+ if user_id in self.active_connections:
136
+ await self.disconnect(user_id)
137
  except Exception as e:
138
  logging.error(f"Error: Send json: {e}")
139
+ # If any send fails, ensure the user gets removed to prevent further errors
140
+ if user_id in self.active_connections:
141
+ await self.disconnect(user_id)
142
 
143
  async def receive_json(self, user_id: UUID) -> dict | None:
144
  try:
145
  websocket = self.get_websocket(user_id)
146
  if websocket:
147
+ try:
148
+ # Receive the raw message and handle JSON parsing manually for better error handling
149
+ try:
150
+ data = await websocket.receive_json()
151
+ # Verify it's a dictionary
152
+ if not isinstance(data, dict):
153
+ logging.error(f"Expected dict but received {type(data)} from user {user_id}: {data}")
154
+ return None
155
+ return data
156
+ except ValueError as json_err:
157
+ # Specific handling for JSON parsing errors
158
+ logging.error(f"JSON parsing error for user {user_id}: {json_err}")
159
+ return None
160
+ except RuntimeError as e:
161
+ error_msg = str(e)
162
+ if any(err in error_msg for err in [
163
+ "WebSocket is not connected",
164
+ "Cannot call \"send\" once a close message has been sent",
165
+ "Cannot call \"receive\" once a close message has been sent",
166
+ "WebSocket is disconnected"]):
167
+ # The websocket was disconnected or closing
168
+ logging.info(f"WebSocket disconnected/closing for user {user_id}: {error_msg}")
169
+ await self.disconnect(user_id)
170
+ else:
171
+ logging.error(f"Runtime error in receive_json: {e}")
172
+ return None
173
+ return None
174
+ except WebSocketDisconnect as disconnect_error:
175
+ # Handle websocket disconnection event (this is a clean, expected path)
176
+ code = disconnect_error.code
177
+ if code == 1006: # ABNORMAL_CLOSURE
178
+ logging.info(f"WebSocket abnormally closed for user {user_id}: Connection was closed without a proper close handshake")
179
+ else:
180
+ logging.info(f"WebSocket disconnected for user {user_id} with code {code}: {disconnect_error.reason}")
181
+
182
+ # Always disconnect the user
183
+ if user_id in self.active_connections:
184
+ await self.disconnect(user_id)
185
  return None
186
  except Exception as e:
187
  logging.error(f"Error: Receive json: {e}")
188
+ # Ensure disconnection on any exception
189
+ if user_id in self.active_connections:
190
+ await self.disconnect(user_id)
191
  return None
192
 
193
  async def receive_bytes(self, user_id: UUID) -> bytes | None:
194
  try:
195
  websocket = self.get_websocket(user_id)
196
  if websocket:
197
+ try:
198
+ return await websocket.receive_bytes()
199
+ except RuntimeError as e:
200
+ error_msg = str(e)
201
+ if any(err in error_msg for err in [
202
+ "WebSocket is not connected",
203
+ "Cannot call \"send\" once a close message has been sent",
204
+ "Cannot call \"receive\" once a close message has been sent",
205
+ "WebSocket is disconnected"]):
206
+ # The websocket was disconnected or closing
207
+ logging.info(f"WebSocket disconnected/closing for user {user_id}: {error_msg}")
208
+ await self.disconnect(user_id)
209
+ else:
210
+ logging.error(f"Runtime error in receive_bytes: {e}")
211
+ return None
212
+ return None
213
+ except WebSocketDisconnect as disconnect_error:
214
+ # Handle websocket disconnection event (this is a clean, expected path)
215
+ code = disconnect_error.code
216
+ if code == 1006: # ABNORMAL_CLOSURE
217
+ logging.info(f"WebSocket abnormally closed for user {user_id}: Connection was closed without a proper close handshake")
218
+ else:
219
+ logging.info(f"WebSocket disconnected for user {user_id} with code {code}: {disconnect_error.reason}")
220
+
221
+ # Always disconnect the user
222
+ if user_id in self.active_connections:
223
+ await self.disconnect(user_id)
224
  return None
225
  except Exception as e:
226
  logging.error(f"Error: Receive bytes: {e}")
227
+ # Ensure disconnection on any exception
228
+ if user_id in self.active_connections:
229
+ await self.disconnect(user_id)
230
  return None
server/main.py CHANGED
@@ -1,4 +1,5 @@
1
- from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
 
2
  from fastapi.responses import StreamingResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.staticfiles import StaticFiles
@@ -9,37 +10,38 @@ from PIL import Image
9
  import logging
10
  from config import config, Args
11
  from connection_manager import ConnectionManager, ServerFullException
12
- import uuid
13
  from uuid import UUID
14
  import time
15
- from typing import Any, Protocol, runtime_checkable
16
  from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class, ParamsModel
17
  from device import device, torch_dtype
18
  import asyncio
19
  import os
20
  import time
21
- import torch
22
- from pydantic import BaseModel, create_model
 
 
 
 
 
 
 
23
 
24
  @runtime_checkable
25
  class BasePipeline(Protocol):
26
  class Info:
27
  @classmethod
28
- def schema(cls) -> dict[str, Any]:
29
- ...
30
  page_content: str | None
31
  input_mode: str
32
-
33
  class InputParams(ParamsModel):
34
  @classmethod
35
- def schema(cls) -> dict[str, Any]:
36
- ...
37
-
38
- def dict(self) -> dict[str, Any]:
39
- ...
40
-
41
- def predict(self, params: ParamsModel) -> Image.Image | None:
42
- ...
43
 
44
 
45
  THROTTLE = 1.0 / 120
@@ -74,7 +76,22 @@ class App:
74
  await handle_websocket_data(user_id)
75
  except ServerFullException as e:
76
  logging.error(f"Server Full: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  finally:
 
78
  await self.conn_manager.disconnect(user_id)
79
  logging.info(f"User disconnected: {user_id}")
80
 
@@ -99,34 +116,73 @@ class App:
99
  return
100
  data = await self.conn_manager.receive_json(user_id)
101
  if data is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  continue
103
-
104
  if data["status"] == "next_frame":
105
  info = self.pipeline.Info()
106
  params_data = await self.conn_manager.receive_json(user_id)
107
  if params_data is None:
 
 
 
 
 
 
108
  continue
109
-
110
  params = self.pipeline.InputParams.model_validate(params_data)
111
-
112
  if info.input_mode == "image":
113
  image_data = await self.conn_manager.receive_bytes(user_id)
114
- if image_data is None or len(image_data) == 0:
 
 
 
 
 
 
115
  await self.conn_manager.send_json(
116
  user_id, {"status": "send_frame"}
117
  )
118
  continue
119
-
120
- # Create a new Pydantic model with the image field
121
- params_dict = params.model_dump()
122
- params_dict["image"] = bytes_to_pil(image_data)
123
- params = self.pipeline.InputParams.model_validate(params_dict)
 
 
 
 
124
 
125
  await self.conn_manager.update_data(user_id, params)
126
  await self.conn_manager.send_json(user_id, {"status": "wait"})
127
 
 
 
 
 
 
 
 
 
 
 
128
  except Exception as e:
129
- logging.error(f"Websocket Error: {e}, {user_id} ")
130
  await self.conn_manager.disconnect(user_id)
131
 
132
  @self.app.get("/api/queue")
@@ -137,25 +193,58 @@ class App:
137
  @self.app.get("/api/stream/{user_id}")
138
  async def stream(user_id: UUID, request: Request) -> StreamingResponse:
139
  try:
140
- async def generate() -> bytes:
 
141
  last_params: ParamsModel | None = None
142
  while True:
 
 
 
 
 
143
  last_time = time.time()
144
- await self.conn_manager.send_json(
145
- user_id, {"status": "send_frame"}
146
- )
 
 
 
 
 
 
 
 
 
 
147
  params = await self.conn_manager.get_latest_data(user_id)
148
-
149
- if (params is None or
150
- (last_params is not None and
151
- params.model_dump() == last_params.model_dump())):
152
  await asyncio.sleep(THROTTLE)
153
  continue
154
-
155
- last_params = params
156
- image = self.pipeline.predict(params)
157
 
158
- if self.args.safety_checker and self.safety_checker is not None and image is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  image, has_nsfw_concept = self.safety_checker(image)
160
  if has_nsfw_concept:
161
  image = None
@@ -175,9 +264,34 @@ class App:
175
  media_type="multipart/x-mixed-replace;boundary=frame",
176
  headers={"Cache-Control": "no-cache"},
177
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  except Exception as e:
179
- logging.error(f"Streaming Error: {e}, {user_id} ")
180
- raise HTTPException(status_code=404, detail="User not found")
 
 
181
 
182
  # route to setup frontend
183
  @self.app.get("/api/settings")
@@ -185,7 +299,7 @@ class App:
185
  info_schema = self.pipeline.Info.schema()
186
  info = self.pipeline.Info()
187
  page_content = ""
188
- if hasattr(info, 'page_content') and info.page_content:
189
  page_content = markdown2.markdown(info.page_content)
190
 
191
  input_params = self.pipeline.InputParams.schema()
@@ -234,7 +348,7 @@ if __name__ == "__main__":
234
  # app = create_app(config) # Create the app once
235
 
236
  uvicorn.run(
237
- app,
238
  host=config.host,
239
  port=config.port,
240
  reload=config.reload,
 
1
+ from fastapi import FastAPI, WebSocket, HTTPException
2
+ from fastapi.websockets import WebSocketDisconnect
3
  from fastapi.responses import StreamingResponse, JSONResponse
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from fastapi.staticfiles import StaticFiles
 
10
  import logging
11
  from config import config, Args
12
  from connection_manager import ConnectionManager, ServerFullException
 
13
  from uuid import UUID
14
  import time
15
+ from typing import Any, Protocol, runtime_checkable, AsyncGenerator
16
  from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class, ParamsModel
17
  from device import device, torch_dtype
18
  import asyncio
19
  import os
20
  import time
21
+
22
+ # Common WebSocket error messages that indicate disconnection
23
+ ERROR_MESSAGES = [
24
+ "WebSocket is not connected",
25
+ 'Cannot call "send" once a close message has been sent',
26
+ 'Cannot call "receive" once a close message has been sent',
27
+ "WebSocket is disconnected",
28
+ ]
29
+
30
 
31
  @runtime_checkable
32
  class BasePipeline(Protocol):
33
  class Info:
34
  @classmethod
35
+ def schema(cls) -> dict[str, Any]: ...
36
+
37
  page_content: str | None
38
  input_mode: str
39
+
40
  class InputParams(ParamsModel):
41
  @classmethod
42
+ def schema(cls) -> dict[str, Any]: ...
43
+
44
+ def predict(self, params: ParamsModel) -> Image.Image | None: ...
 
 
 
 
 
45
 
46
 
47
  THROTTLE = 1.0 / 120
 
76
  await handle_websocket_data(user_id)
77
  except ServerFullException as e:
78
  logging.error(f"Server Full: {e}")
79
+ except WebSocketDisconnect as disconnect_error:
80
+ # Handle websocket disconnection event
81
+ code = disconnect_error.code
82
+ if code == 1006: # ABNORMAL_CLOSURE
83
+ logging.info(f"WebSocket abnormally closed for user {user_id}: Connection was closed without a proper close handshake")
84
+ else:
85
+ logging.info(f"WebSocket disconnected for user {user_id} with code {code}: {disconnect_error.reason}")
86
+ except RuntimeError as e:
87
+ if any(err in str(e) for err in ERROR_MESSAGES):
88
+ logging.info(f"WebSocket disconnected for user {user_id}: {e}")
89
+ else:
90
+ logging.error(f"Runtime error in websocket endpoint: {e}")
91
+ except Exception as e:
92
+ logging.error(f"Unexpected error in websocket endpoint: {e}")
93
  finally:
94
+ # Always ensure we disconnect the user
95
  await self.conn_manager.disconnect(user_id)
96
  logging.info(f"User disconnected: {user_id}")
97
 
 
116
  return
117
  data = await self.conn_manager.receive_json(user_id)
118
  if data is None:
119
+ # Check if the user is still connected - they might have disconnected
120
+ if not self.conn_manager.check_user(user_id):
121
+ logging.info(
122
+ f"User {user_id} disconnected, exiting handle_websocket_data loop"
123
+ )
124
+ return
125
+ continue
126
+
127
+ # Validate that data is a dictionary and has a status field
128
+ if not isinstance(data, dict) or "status" not in data:
129
+ logging.error(
130
+ f"Invalid data format received from user {user_id}: {data}"
131
+ )
132
  continue
133
+
134
  if data["status"] == "next_frame":
135
  info = self.pipeline.Info()
136
  params_data = await self.conn_manager.receive_json(user_id)
137
  if params_data is None:
138
+ # Check if the user is still connected
139
+ if not self.conn_manager.check_user(user_id):
140
+ logging.info(
141
+ f"User {user_id} disconnected during params reception"
142
+ )
143
+ return
144
  continue
145
+
146
  params = self.pipeline.InputParams.model_validate(params_data)
147
+
148
  if info.input_mode == "image":
149
  image_data = await self.conn_manager.receive_bytes(user_id)
150
+ if image_data is None:
151
+ # Check if the user is still connected
152
+ if not self.conn_manager.check_user(user_id):
153
+ logging.info(
154
+ f"User {user_id} disconnected during image reception"
155
+ )
156
+ return
157
  await self.conn_manager.send_json(
158
  user_id, {"status": "send_frame"}
159
  )
160
  continue
161
+ if len(image_data) == 0:
162
+ await self.conn_manager.send_json(
163
+ user_id, {"status": "send_frame"}
164
+ )
165
+ continue
166
+
167
+ # Add the image directly to the model using setattr
168
+ # This works because we've configured the ParamsModel to allow extra fields
169
+ setattr(params, "image", bytes_to_pil(image_data))
170
 
171
  await self.conn_manager.update_data(user_id, params)
172
  await self.conn_manager.send_json(user_id, {"status": "wait"})
173
 
174
+ except RuntimeError as e:
175
+ error_msg = str(e)
176
+ if any(err in error_msg for err in ERROR_MESSAGES):
177
+ logging.info(
178
+ f"WebSocket disconnected for user {user_id}: {error_msg}"
179
+ )
180
+ else:
181
+ logging.error(f"Websocket Runtime Error: {e}, {user_id}")
182
+ # Ensure disconnect is called
183
+ await self.conn_manager.disconnect(user_id)
184
  except Exception as e:
185
+ logging.error(f"Websocket Error: {e}, {user_id}")
186
  await self.conn_manager.disconnect(user_id)
187
 
188
  @self.app.get("/api/queue")
 
193
  @self.app.get("/api/stream/{user_id}")
194
  async def stream(user_id: UUID, request: Request) -> StreamingResponse:
195
  try:
196
+
197
+ async def generate() -> AsyncGenerator[bytes, None]:
198
  last_params: ParamsModel | None = None
199
  while True:
200
+ # Check if the user is still connected
201
+ if not self.conn_manager.check_user(user_id):
202
+ logging.info(f"User {user_id} disconnected from stream")
203
+ break
204
+
205
  last_time = time.time()
206
+ try:
207
+ await self.conn_manager.send_json(
208
+ user_id, {"status": "send_frame"}
209
+ )
210
+ except Exception as e:
211
+ logging.error(f"Error sending to websocket in stream: {e}")
212
+ # User might have disconnected
213
+ if not self.conn_manager.check_user(user_id):
214
+ logging.info(f"User {user_id} disconnected from stream")
215
+ break
216
+ await asyncio.sleep(THROTTLE)
217
+ continue
218
+
219
  params = await self.conn_manager.get_latest_data(user_id)
220
+
221
+ if params is None:
 
 
222
  await asyncio.sleep(THROTTLE)
223
  continue
 
 
 
224
 
225
+ try:
226
+ # Check if the params haven't changed since last time
227
+ if (
228
+ last_params is not None
229
+ and params.model_dump() == last_params.model_dump()
230
+ ):
231
+ await asyncio.sleep(THROTTLE)
232
+ continue
233
+
234
+ last_params = params
235
+ image = self.pipeline.predict(params)
236
+ except Exception as e:
237
+ logging.error(
238
+ f"Error processing params for user {user_id}: {e}"
239
+ )
240
+ await asyncio.sleep(THROTTLE)
241
+ continue
242
+
243
+ if (
244
+ self.args.safety_checker
245
+ and self.safety_checker is not None
246
+ and image is not None
247
+ ):
248
  image, has_nsfw_concept = self.safety_checker(image)
249
  if has_nsfw_concept:
250
  image = None
 
264
  media_type="multipart/x-mixed-replace;boundary=frame",
265
  headers={"Cache-Control": "no-cache"},
266
  )
267
+ except WebSocketDisconnect as disconnect_error:
268
+ # Handle websocket disconnection event
269
+ code = disconnect_error.code
270
+ if code == 1006: # ABNORMAL_CLOSURE
271
+ logging.info(f"WebSocket abnormally closed during streaming for user {user_id}: Connection was closed without a proper close handshake")
272
+ else:
273
+ logging.info(f"WebSocket disconnected during streaming for user {user_id} with code {code}: {disconnect_error.reason}")
274
+
275
+ # Clean disconnection without error response
276
+ await self.conn_manager.disconnect(user_id)
277
+ raise HTTPException(status_code=204, detail="Connection closed")
278
+ except RuntimeError as e:
279
+ error_msg = str(e)
280
+ if any(err in error_msg for err in ERROR_MESSAGES):
281
+ logging.info(
282
+ f"WebSocket disconnected during streaming for user {user_id}: {error_msg}"
283
+ )
284
+ # Clean disconnection without error response
285
+ await self.conn_manager.disconnect(user_id)
286
+ raise HTTPException(status_code=204, detail="Connection closed")
287
+ else:
288
+ logging.error(f"Streaming Runtime Error: {e}, {user_id}")
289
+ raise HTTPException(status_code=500, detail="Streaming error")
290
  except Exception as e:
291
+ logging.error(f"Streaming Error: {e}, {user_id}")
292
+ # Always ensure we disconnect the user on error
293
+ await self.conn_manager.disconnect(user_id)
294
+ raise HTTPException(status_code=500, detail="Streaming error")
295
 
296
  # route to setup frontend
297
  @self.app.get("/api/settings")
 
299
  info_schema = self.pipeline.Info.schema()
300
  info = self.pipeline.Info()
301
  page_content = ""
302
+ if hasattr(info, "page_content") and info.page_content:
303
  page_content = markdown2.markdown(info.page_content)
304
 
305
  input_params = self.pipeline.InputParams.schema()
 
348
  # app = create_app(config) # Create the app once
349
 
350
  uvicorn.run(
351
+ "main:app",
352
  host=config.host,
353
  port=config.port,
354
  reload=config.reload,
server/pipelines/img2img.py CHANGED
@@ -14,6 +14,7 @@ import psutil
14
  from config import Args
15
  from pydantic import BaseModel, Field
16
  from PIL import Image
 
17
  import math
18
 
19
  base_model = "SimianLuo/LCM_Dreamshaper_v7"
@@ -54,7 +55,7 @@ class Pipeline:
54
  input_mode: str = "image"
55
  page_content: str = page_content
56
 
57
- class InputParams(BaseModel):
58
  prompt: str = Field(
59
  default_prompt,
60
  title="Prompt",
 
14
  from config import Args
15
  from pydantic import BaseModel, Field
16
  from PIL import Image
17
+ from util import ParamsModel
18
  import math
19
 
20
  base_model = "SimianLuo/LCM_Dreamshaper_v7"
 
55
  input_mode: str = "image"
56
  page_content: str = page_content
57
 
58
+ class InputParams(ParamsModel):
59
  prompt: str = Field(
60
  default_prompt,
61
  title="Prompt",
server/pipelines/img2imgFlux.py CHANGED
@@ -28,6 +28,7 @@ from config import Args
28
  from pydantic import BaseModel, Field
29
  from PIL import Image
30
  from pathlib import Path
 
31
  import math
32
  import gc
33
 
@@ -61,7 +62,7 @@ class Pipeline:
61
  input_mode: str = "image"
62
  page_content: str = page_content
63
 
64
- class InputParams(BaseModel):
65
  prompt: str = Field(
66
  default_prompt,
67
  title="Prompt",
 
28
  from pydantic import BaseModel, Field
29
  from PIL import Image
30
  from pathlib import Path
31
+ from util import ParamsModel
32
  import math
33
  import gc
34
 
 
62
  input_mode: str = "image"
63
  page_content: str = page_content
64
 
65
+ class InputParams(ParamsModel):
66
  prompt: str = Field(
67
  default_prompt,
68
  title="Prompt",
server/pipelines/img2imgSDTurbo.py CHANGED
@@ -9,10 +9,10 @@ try:
9
  except:
10
  pass
11
 
12
- import psutil
13
  from config import Args
14
  from pydantic import BaseModel, Field
15
  from PIL import Image
 
16
  import math
17
 
18
 
@@ -55,7 +55,7 @@ class Pipeline:
55
  input_mode: str = "image"
56
  page_content: str = page_content
57
 
58
- class InputParams(BaseModel):
59
  prompt: str = Field(
60
  default_prompt,
61
  title="Prompt",
 
9
  except:
10
  pass
11
 
 
12
  from config import Args
13
  from pydantic import BaseModel, Field
14
  from PIL import Image
15
+ from util import ParamsModel
16
  import math
17
 
18
 
 
55
  input_mode: str = "image"
56
  page_content: str = page_content
57
 
58
+ class InputParams(ParamsModel):
59
  prompt: str = Field(
60
  default_prompt,
61
  title="Prompt",
server/pipelines/img2imgSDXL-Lightning.py CHANGED
@@ -18,6 +18,7 @@ from huggingface_hub import hf_hub_download
18
  from config import Args
19
  from pydantic import BaseModel, Field
20
  from PIL import Image
 
21
  import math
22
 
23
  base = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -62,7 +63,7 @@ class Pipeline:
62
  input_mode: str = "image"
63
  page_content: str = page_content
64
 
65
- class InputParams(BaseModel):
66
  prompt: str = Field(
67
  default_prompt,
68
  title="Prompt",
 
18
  from config import Args
19
  from pydantic import BaseModel, Field
20
  from PIL import Image
21
+ from util import ParamsModel
22
  import math
23
 
24
  base = "stabilityai/stable-diffusion-xl-base-1.0"
 
63
  input_mode: str = "image"
64
  page_content: str = page_content
65
 
66
+ class InputParams(ParamsModel):
67
  prompt: str = Field(
68
  default_prompt,
69
  title="Prompt",
server/pipelines/img2imgSDXLTurbo.py CHANGED
@@ -14,6 +14,7 @@ import psutil
14
  from config import Args
15
  from pydantic import BaseModel, Field
16
  from PIL import Image
 
17
  import math
18
 
19
  base_model = "stabilityai/sdxl-turbo"
@@ -55,7 +56,7 @@ class Pipeline:
55
  input_mode: str = "image"
56
  page_content: str = page_content
57
 
58
- class InputParams(BaseModel):
59
  prompt: str = Field(
60
  default_prompt,
61
  title="Prompt",
 
14
  from config import Args
15
  from pydantic import BaseModel, Field
16
  from PIL import Image
17
+ from util import ParamsModel
18
  import math
19
 
20
  base_model = "stabilityai/sdxl-turbo"
 
56
  input_mode: str = "image"
57
  page_content: str = page_content
58
 
59
+ class InputParams(ParamsModel):
60
  prompt: str = Field(
61
  default_prompt,
62
  title="Prompt",
server/pipelines/img2imgSDXS512.py CHANGED
@@ -11,6 +11,7 @@ import psutil
11
  from config import Args
12
  from pydantic import BaseModel, Field
13
  from PIL import Image
 
14
  import math
15
 
16
  base_model = "IDKiro/sdxs-512-0.9"
@@ -51,7 +52,7 @@ class Pipeline:
51
  input_mode: str = "image"
52
  page_content: str = page_content
53
 
54
- class InputParams(BaseModel):
55
  prompt: str = Field(
56
  default_prompt,
57
  title="Prompt",
 
11
  from config import Args
12
  from pydantic import BaseModel, Field
13
  from PIL import Image
14
+ from util import ParamsModel
15
  import math
16
 
17
  base_model = "IDKiro/sdxs-512-0.9"
 
52
  input_mode: str = "image"
53
  page_content: str = page_content
54
 
55
+ class InputParams(ParamsModel):
56
  prompt: str = Field(
57
  default_prompt,
58
  title="Prompt",
server/util.py CHANGED
@@ -1,5 +1,5 @@
1
  from importlib import import_module
2
- from typing import Any, TypeVar, type_check_only
3
  from PIL import Image
4
  import io
5
  from pydantic import BaseModel
@@ -11,12 +11,17 @@ TPipeline = TypeVar("TPipeline", bound=type[Any])
11
 
12
  class ParamsModel(BaseModel):
13
  """Base model for pipeline parameters."""
14
-
 
 
 
 
 
15
  @classmethod
16
- def from_dict(cls, data: dict[str, Any]) -> 'ParamsModel':
17
  """Create a model instance from dictionary data."""
18
  return cls.model_validate(data)
19
-
20
  def to_dict(self) -> dict[str, Any]:
21
  """Convert model to dictionary."""
22
  return self.model_dump()
@@ -25,13 +30,13 @@ class ParamsModel(BaseModel):
25
  def get_pipeline_class(pipeline_name: str) -> type:
26
  """
27
  Dynamically imports and returns the Pipeline class from a specified module.
28
-
29
  Args:
30
  pipeline_name: The name of the pipeline module to import
31
-
32
  Returns:
33
  The Pipeline class from the specified module
34
-
35
  Raises:
36
  ValueError: If the module or Pipeline class isn't found
37
  TypeError: If Pipeline is not a class
 
1
  from importlib import import_module
2
+ from typing import Any, TypeVar
3
  from PIL import Image
4
  import io
5
  from pydantic import BaseModel
 
11
 
12
  class ParamsModel(BaseModel):
13
  """Base model for pipeline parameters."""
14
+
15
+ model_config = {
16
+ "arbitrary_types_allowed": True,
17
+ "extra": "allow", # Allow extra attributes for dynamic fields like 'image'
18
+ }
19
+
20
  @classmethod
21
+ def from_dict(cls, data: dict[str, Any]) -> "ParamsModel":
22
  """Create a model instance from dictionary data."""
23
  return cls.model_validate(data)
24
+
25
  def to_dict(self) -> dict[str, Any]:
26
  """Convert model to dictionary."""
27
  return self.model_dump()
 
30
  def get_pipeline_class(pipeline_name: str) -> type:
31
  """
32
  Dynamically imports and returns the Pipeline class from a specified module.
33
+
34
  Args:
35
  pipeline_name: The name of the pipeline module to import
36
+
37
  Returns:
38
  The Pipeline class from the specified module
39
+
40
  Raises:
41
  ValueError: If the module or Pipeline class isn't found
42
  TypeError: If Pipeline is not a class