Spaces:
Running
on
A100
Running
on
A100
refactor: fix websocket errors
Browse files- frontend/src/lib/components/ImagePlayer.svelte +24 -3
- frontend/src/lib/lcmLive.ts +159 -48
- frontend/src/routes/+page.svelte +79 -15
- server/config.py +19 -7
- server/connection_manager.py +122 -8
- server/main.py +156 -42
- server/pipelines/img2img.py +2 -1
- server/pipelines/img2imgFlux.py +2 -1
- server/pipelines/img2imgSDTurbo.py +2 -2
- server/pipelines/img2imgSDXL-Lightning.py +2 -1
- server/pipelines/img2imgSDXLTurbo.py +2 -1
- server/pipelines/img2imgSDXS512.py +2 -1
- server/util.py +12 -7
frontend/src/lib/components/ImagePlayer.svelte
CHANGED
@@ -7,8 +7,8 @@
|
|
7 |
import Expand from '$lib/icons/expand.svelte';
|
8 |
import { snapImage, expandWindow } from '$lib/utils';
|
9 |
|
10 |
-
$: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED
|
11 |
-
|
12 |
let imageEl: HTMLImageElement;
|
13 |
let expandedWindow: Window;
|
14 |
let isExpanded = false;
|
@@ -40,12 +40,26 @@
|
|
40 |
class="relative mx-auto aspect-square max-w-lg self-center overflow-hidden rounded-lg border border-slate-300"
|
41 |
>
|
42 |
<!-- svelte-ignore a11y-missing-attribute -->
|
43 |
-
{#if
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
{#if !isExpanded}
|
|
|
45 |
<img
|
46 |
bind:this={imageEl}
|
47 |
class="aspect-square w-full rounded-lg"
|
48 |
src={'/api/stream/' + $streamId}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
/>
|
50 |
{/if}
|
51 |
<div class="absolute bottom-1 right-1">
|
@@ -65,6 +79,13 @@
|
|
65 |
<Floppy classList={''} />
|
66 |
</Button>
|
67 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
{:else}
|
69 |
<img
|
70 |
class="aspect-square w-full rounded-lg"
|
|
|
7 |
import Expand from '$lib/icons/expand.svelte';
|
8 |
import { snapImage, expandWindow } from '$lib/utils';
|
9 |
|
10 |
+
$: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED &&
|
11 |
+
$lcmLiveStatus !== LCMLiveStatus.ERROR;
|
12 |
let imageEl: HTMLImageElement;
|
13 |
let expandedWindow: Window;
|
14 |
let isExpanded = false;
|
|
|
40 |
class="relative mx-auto aspect-square max-w-lg self-center overflow-hidden rounded-lg border border-slate-300"
|
41 |
>
|
42 |
<!-- svelte-ignore a11y-missing-attribute -->
|
43 |
+
{#if $lcmLiveStatus === LCMLiveStatus.CONNECTING}
|
44 |
+
<!-- Show connecting spinner -->
|
45 |
+
<div class="flex items-center justify-center h-full w-full">
|
46 |
+
<div class="animate-spin rounded-full h-16 w-16 border-b-2 border-white"></div>
|
47 |
+
<p class="text-white ml-2">Connecting...</p>
|
48 |
+
</div>
|
49 |
+
{:else if isLCMRunning}
|
50 |
{#if !isExpanded}
|
51 |
+
<!-- Handle image error by adding onerror event -->
|
52 |
<img
|
53 |
bind:this={imageEl}
|
54 |
class="aspect-square w-full rounded-lg"
|
55 |
src={'/api/stream/' + $streamId}
|
56 |
+
on:error={(e) => {
|
57 |
+
console.error('Image stream error:', e);
|
58 |
+
// If stream fails to load, set status to error
|
59 |
+
if ($lcmLiveStatus !== LCMLiveStatus.ERROR) {
|
60 |
+
lcmLiveStatus.set(LCMLiveStatus.ERROR);
|
61 |
+
}
|
62 |
+
}}
|
63 |
/>
|
64 |
{/if}
|
65 |
<div class="absolute bottom-1 right-1">
|
|
|
79 |
<Floppy classList={''} />
|
80 |
</Button>
|
81 |
</div>
|
82 |
+
{:else if $lcmLiveStatus === LCMLiveStatus.ERROR}
|
83 |
+
<!-- Show error state with red border -->
|
84 |
+
<div class="flex items-center justify-center h-full w-full border-2 border-red-500 rounded-lg bg-gray-900">
|
85 |
+
<p class="text-center text-white p-4">
|
86 |
+
Connection error
|
87 |
+
</p>
|
88 |
+
</div>
|
89 |
{:else}
|
90 |
<img
|
91 |
class="aspect-square w-full rounded-lg"
|
frontend/src/lib/lcmLive.ts
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
-
import { writable } from 'svelte/store';
|
2 |
|
3 |
export enum LCMLiveStatus {
|
4 |
CONNECTED = 'connected',
|
5 |
DISCONNECTED = 'disconnected',
|
|
|
6 |
WAIT = 'wait',
|
7 |
SEND_FRAME = 'send_frame',
|
8 |
-
TIMEOUT = 'timeout'
|
|
|
9 |
}
|
10 |
|
11 |
const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
|
@@ -13,85 +15,194 @@ const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
|
|
13 |
export const lcmLiveStatus = writable<LCMLiveStatus>(initStatus);
|
14 |
export const streamId = writable<string | null>(null);
|
15 |
|
|
|
16 |
let websocket: WebSocket | null = null;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
export const lcmLiveActions = {
|
18 |
async start(getSreamdata: () => any[]) {
|
19 |
return new Promise((resolve, reject) => {
|
20 |
try {
|
|
|
|
|
|
|
21 |
const userId = crypto.randomUUID();
|
22 |
const websocketURL = `${
|
23 |
window.location.protocol === 'https:' ? 'wss' : 'ws'
|
24 |
}:${window.location.host}/api/ws/${userId}`;
|
25 |
|
|
|
|
|
|
|
|
|
|
|
26 |
websocket = new WebSocket(websocketURL);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
websocket.onopen = () => {
|
|
|
28 |
console.log('Connected to websocket');
|
29 |
};
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
};
|
|
|
34 |
websocket.onerror = (err) => {
|
35 |
-
|
|
|
|
|
|
|
|
|
36 |
};
|
|
|
37 |
websocket.onmessage = (event) => {
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
}
|
69 |
};
|
70 |
} catch (err) {
|
71 |
-
console.error(err);
|
72 |
-
lcmLiveStatus.set(LCMLiveStatus.
|
73 |
streamId.set(null);
|
74 |
reject(err);
|
75 |
}
|
76 |
});
|
77 |
},
|
78 |
send(data: Blob | { [key: string]: any }) {
|
79 |
-
|
80 |
-
if (
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
} else {
|
83 |
-
websocket
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
}
|
85 |
-
}
|
86 |
-
console.
|
|
|
|
|
87 |
}
|
88 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
async stop() {
|
90 |
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
91 |
-
|
92 |
-
websocket
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
}
|
94 |
-
websocket = null;
|
95 |
-
streamId.set(null);
|
96 |
}
|
97 |
};
|
|
|
1 |
+
import { get, writable } from 'svelte/store';
|
2 |
|
3 |
export enum LCMLiveStatus {
|
4 |
CONNECTED = 'connected',
|
5 |
DISCONNECTED = 'disconnected',
|
6 |
+
CONNECTING = 'connecting',
|
7 |
WAIT = 'wait',
|
8 |
SEND_FRAME = 'send_frame',
|
9 |
+
TIMEOUT = 'timeout',
|
10 |
+
ERROR = 'error'
|
11 |
}
|
12 |
|
13 |
const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
|
|
|
15 |
export const lcmLiveStatus = writable<LCMLiveStatus>(initStatus);
|
16 |
export const streamId = writable<string | null>(null);
|
17 |
|
18 |
+
// WebSocket connection
|
19 |
let websocket: WebSocket | null = null;
|
20 |
+
// Flag to track intentional connection closure
|
21 |
+
let intentionalClosure = false;
|
22 |
+
|
23 |
+
// Register browser unload event listener to properly close WebSockets
|
24 |
+
if (typeof window !== 'undefined') {
|
25 |
+
window.addEventListener('beforeunload', () => {
|
26 |
+
// Mark any closure during page unload as intentional
|
27 |
+
intentionalClosure = true;
|
28 |
+
// Close the WebSocket properly if it exists
|
29 |
+
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
30 |
+
websocket.close(1000, 'Page unload');
|
31 |
+
}
|
32 |
+
});
|
33 |
+
}
|
34 |
export const lcmLiveActions = {
|
35 |
async start(getSreamdata: () => any[]) {
|
36 |
return new Promise((resolve, reject) => {
|
37 |
try {
|
38 |
+
// Set connecting status immediately
|
39 |
+
lcmLiveStatus.set(LCMLiveStatus.CONNECTING);
|
40 |
+
|
41 |
const userId = crypto.randomUUID();
|
42 |
const websocketURL = `${
|
43 |
window.location.protocol === 'https:' ? 'wss' : 'ws'
|
44 |
}:${window.location.host}/api/ws/${userId}`;
|
45 |
|
46 |
+
// Close any existing connection first
|
47 |
+
if (websocket && websocket.readyState !== WebSocket.CLOSED) {
|
48 |
+
websocket.close();
|
49 |
+
}
|
50 |
+
|
51 |
websocket = new WebSocket(websocketURL);
|
52 |
+
|
53 |
+
// Set a connection timeout
|
54 |
+
const connectionTimeout = setTimeout(() => {
|
55 |
+
if (websocket && websocket.readyState !== WebSocket.OPEN) {
|
56 |
+
console.error('WebSocket connection timeout');
|
57 |
+
lcmLiveStatus.set(LCMLiveStatus.ERROR);
|
58 |
+
streamId.set(null);
|
59 |
+
reject(new Error('Connection timeout. Please try again.'));
|
60 |
+
websocket.close();
|
61 |
+
}
|
62 |
+
}, 10000); // 10 second timeout
|
63 |
+
|
64 |
websocket.onopen = () => {
|
65 |
+
clearTimeout(connectionTimeout);
|
66 |
console.log('Connected to websocket');
|
67 |
};
|
68 |
+
|
69 |
+
websocket.onclose = (event) => {
|
70 |
+
clearTimeout(connectionTimeout);
|
71 |
+
console.log(`Disconnected from websocket: ${event.code} ${event.reason}`);
|
72 |
+
|
73 |
+
// Only change status if we're not in ERROR state (which would mean we already handled the error)
|
74 |
+
if (get(lcmLiveStatus) !== LCMLiveStatus.ERROR) {
|
75 |
+
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
76 |
+
}
|
77 |
+
|
78 |
+
// If connection was never established (close without open)
|
79 |
+
if (event.code === 1006 && streamId.get() === null) {
|
80 |
+
reject(new Error('Cannot connect to server. Please try again later.'));
|
81 |
+
}
|
82 |
};
|
83 |
+
|
84 |
websocket.onerror = (err) => {
|
85 |
+
clearTimeout(connectionTimeout);
|
86 |
+
console.error('WebSocket error:', err);
|
87 |
+
lcmLiveStatus.set(LCMLiveStatus.ERROR);
|
88 |
+
streamId.set(null);
|
89 |
+
reject(new Error('Connection error. Please try again.'));
|
90 |
};
|
91 |
+
|
92 |
websocket.onmessage = (event) => {
|
93 |
+
try {
|
94 |
+
const data = JSON.parse(event.data);
|
95 |
+
switch (data.status) {
|
96 |
+
case 'connected':
|
97 |
+
lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
|
98 |
+
streamId.set(userId);
|
99 |
+
resolve({ status: 'connected', userId });
|
100 |
+
break;
|
101 |
+
case 'send_frame':
|
102 |
+
lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
|
103 |
+
try {
|
104 |
+
const streamData = getSreamdata();
|
105 |
+
// Send as an object, not a string, to use the proper handling in the send method
|
106 |
+
this.send({ status: 'next_frame' });
|
107 |
+
for (const d of streamData) {
|
108 |
+
this.send(d);
|
109 |
+
}
|
110 |
+
} catch (error) {
|
111 |
+
console.error('Error sending frame data:', error);
|
112 |
+
}
|
113 |
+
break;
|
114 |
+
case 'wait':
|
115 |
+
lcmLiveStatus.set(LCMLiveStatus.WAIT);
|
116 |
+
break;
|
117 |
+
case 'timeout':
|
118 |
+
console.log('Session timeout');
|
119 |
+
lcmLiveStatus.set(LCMLiveStatus.TIMEOUT);
|
120 |
+
streamId.set(null);
|
121 |
+
reject(new Error('Session timeout. Please restart.'));
|
122 |
+
break;
|
123 |
+
case 'error':
|
124 |
+
console.error('Server error:', data.message);
|
125 |
+
lcmLiveStatus.set(LCMLiveStatus.ERROR);
|
126 |
+
streamId.set(null);
|
127 |
+
reject(new Error(data.message || 'Server error occurred'));
|
128 |
+
break;
|
129 |
+
default:
|
130 |
+
console.log('Unknown message status:', data.status);
|
131 |
+
}
|
132 |
+
} catch (error) {
|
133 |
+
console.error('Error handling websocket message:', error);
|
134 |
}
|
135 |
};
|
136 |
} catch (err) {
|
137 |
+
console.error('Error initializing websocket:', err);
|
138 |
+
lcmLiveStatus.set(LCMLiveStatus.ERROR);
|
139 |
streamId.set(null);
|
140 |
reject(err);
|
141 |
}
|
142 |
});
|
143 |
},
|
144 |
send(data: Blob | { [key: string]: any }) {
|
145 |
+
try {
|
146 |
+
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
147 |
+
if (data instanceof Blob) {
|
148 |
+
websocket.send(data);
|
149 |
+
} else {
|
150 |
+
websocket.send(JSON.stringify(data));
|
151 |
+
}
|
152 |
} else {
|
153 |
+
const readyStateText = websocket
|
154 |
+
? ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED'][websocket.readyState]
|
155 |
+
: 'null';
|
156 |
+
console.warn(`WebSocket not ready for sending: ${readyStateText}`);
|
157 |
+
|
158 |
+
// If WebSocket is closed unexpectedly, set status to disconnected
|
159 |
+
if (!websocket || websocket.readyState === WebSocket.CLOSED) {
|
160 |
+
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
161 |
+
streamId.set(null);
|
162 |
+
}
|
163 |
}
|
164 |
+
} catch (error) {
|
165 |
+
console.error('Error sending data through WebSocket:', error);
|
166 |
+
// Handle WebSocket error by forcing disconnection
|
167 |
+
this.stop();
|
168 |
}
|
169 |
},
|
170 |
+
|
171 |
+
async reconnect(getSreamdata: () => any[]) {
|
172 |
+
try {
|
173 |
+
await this.stop();
|
174 |
+
// Small delay to ensure clean disconnection before reconnecting
|
175 |
+
await new Promise((resolve) => setTimeout(resolve, 500));
|
176 |
+
return await this.start(getSreamdata);
|
177 |
+
} catch (error) {
|
178 |
+
console.error('Reconnection failed:', error);
|
179 |
+
throw error;
|
180 |
+
}
|
181 |
+
},
|
182 |
+
|
183 |
async stop() {
|
184 |
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
185 |
+
try {
|
186 |
+
if (websocket) {
|
187 |
+
// Only attempt to close if not already closed
|
188 |
+
if (websocket.readyState !== WebSocket.CLOSED) {
|
189 |
+
// Set up onclose handler to clean up only
|
190 |
+
websocket.onclose = () => {
|
191 |
+
console.log('WebSocket closed cleanly during stop()');
|
192 |
+
};
|
193 |
+
|
194 |
+
// Set up onerror to be silent during intentional closure
|
195 |
+
websocket.onerror = () => {};
|
196 |
+
|
197 |
+
websocket.close(1000, 'Client initiated disconnect');
|
198 |
+
}
|
199 |
+
}
|
200 |
+
} catch (error) {
|
201 |
+
console.error('Error during WebSocket closure:', error);
|
202 |
+
} finally {
|
203 |
+
// Always clean up references
|
204 |
+
websocket = null;
|
205 |
+
streamId.set(null);
|
206 |
}
|
|
|
|
|
207 |
}
|
208 |
};
|
frontend/src/routes/+page.svelte
CHANGED
@@ -57,35 +57,86 @@
|
|
57 |
}
|
58 |
}
|
59 |
|
60 |
-
$: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
}
|
64 |
let disabled = false;
|
65 |
async function toggleLcmLive() {
|
66 |
try {
|
67 |
if (!isLCMRunning) {
|
68 |
-
if (
|
69 |
-
|
70 |
-
await mediaStreamActions.start();
|
71 |
}
|
|
|
|
|
|
|
72 |
disabled = true;
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
} else {
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
}
|
80 |
-
lcmLiveActions.stop();
|
81 |
-
toggleQueueChecker(true);
|
82 |
}
|
83 |
} catch (e) {
|
84 |
-
|
|
|
85 |
disabled = false;
|
86 |
toggleQueueChecker(true);
|
87 |
}
|
88 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
</script>
|
90 |
|
91 |
<svelte:head>
|
@@ -111,6 +162,17 @@
|
|
111 |
> and run it on your own GPU.
|
112 |
</p>
|
113 |
{/if}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
</article>
|
115 |
{#if pipelineParams}
|
116 |
<article class="my-3 grid grid-cols-1 gap-3 sm:grid-cols-4">
|
@@ -127,7 +189,9 @@
|
|
127 |
</div>
|
128 |
<div class="sm:col-span-4 sm:row-start-2">
|
129 |
<Button on:click={toggleLcmLive} {disabled} classList={'text-lg my-1 p-2'}>
|
130 |
-
{#if
|
|
|
|
|
131 |
Stop
|
132 |
{:else}
|
133 |
Start
|
|
|
57 |
}
|
58 |
}
|
59 |
|
60 |
+
$: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED &&
|
61 |
+
$lcmLiveStatus !== LCMLiveStatus.ERROR;
|
62 |
+
$: isConnecting = $lcmLiveStatus === LCMLiveStatus.CONNECTING;
|
63 |
+
|
64 |
+
$: {
|
65 |
+
// Set warning messages based on lcmLiveStatus
|
66 |
+
if ($lcmLiveStatus === LCMLiveStatus.TIMEOUT) {
|
67 |
+
warningMessage = 'Session timed out. Please try again.';
|
68 |
+
} else if ($lcmLiveStatus === LCMLiveStatus.ERROR) {
|
69 |
+
warningMessage = 'Connection error occurred. Please try again.';
|
70 |
+
}
|
71 |
}
|
72 |
let disabled = false;
|
73 |
async function toggleLcmLive() {
|
74 |
try {
|
75 |
if (!isLCMRunning) {
|
76 |
+
if (isConnecting) {
|
77 |
+
return; // Don't allow multiple connection attempts
|
|
|
78 |
}
|
79 |
+
|
80 |
+
// Clear any previous warning messages
|
81 |
+
warningMessage = '';
|
82 |
disabled = true;
|
83 |
+
|
84 |
+
try {
|
85 |
+
if (isImageMode) {
|
86 |
+
await mediaStreamActions.enumerateDevices();
|
87 |
+
await mediaStreamActions.start();
|
88 |
+
}
|
89 |
+
|
90 |
+
await lcmLiveActions.start(getSreamdata);
|
91 |
+
toggleQueueChecker(false);
|
92 |
+
} finally {
|
93 |
+
// Always re-enable the button even if there was an error
|
94 |
+
disabled = false;
|
95 |
+
}
|
96 |
} else {
|
97 |
+
// Handle stopping - disable button during this process too
|
98 |
+
disabled = true;
|
99 |
+
|
100 |
+
try {
|
101 |
+
if (isImageMode) {
|
102 |
+
mediaStreamActions.stop();
|
103 |
+
}
|
104 |
+
await lcmLiveActions.stop();
|
105 |
+
toggleQueueChecker(true);
|
106 |
+
} finally {
|
107 |
+
disabled = false;
|
108 |
}
|
|
|
|
|
109 |
}
|
110 |
} catch (e) {
|
111 |
+
console.error('Error in toggleLcmLive:', e);
|
112 |
+
warningMessage = e instanceof Error ? e.message : 'An unknown error occurred';
|
113 |
disabled = false;
|
114 |
toggleQueueChecker(true);
|
115 |
}
|
116 |
}
|
117 |
+
|
118 |
+
// Reconnect function for automatic reconnection
|
119 |
+
async function reconnect() {
|
120 |
+
try {
|
121 |
+
disabled = true;
|
122 |
+
warningMessage = 'Reconnecting...';
|
123 |
+
|
124 |
+
if (isImageMode) {
|
125 |
+
await mediaStreamActions.stop();
|
126 |
+
await mediaStreamActions.enumerateDevices();
|
127 |
+
await mediaStreamActions.start();
|
128 |
+
}
|
129 |
+
|
130 |
+
await lcmLiveActions.reconnect(getSreamdata);
|
131 |
+
warningMessage = '';
|
132 |
+
toggleQueueChecker(false);
|
133 |
+
} catch (e) {
|
134 |
+
warningMessage = e instanceof Error ? e.message : 'Reconnection failed';
|
135 |
+
toggleQueueChecker(true);
|
136 |
+
} finally {
|
137 |
+
disabled = false;
|
138 |
+
}
|
139 |
+
}
|
140 |
</script>
|
141 |
|
142 |
<svelte:head>
|
|
|
162 |
> and run it on your own GPU.
|
163 |
</p>
|
164 |
{/if}
|
165 |
+
|
166 |
+
{#if $lcmLiveStatus === LCMLiveStatus.ERROR}
|
167 |
+
<p class="text-sm mt-2">
|
168 |
+
<button
|
169 |
+
class="text-blue-500 underline hover:no-underline"
|
170 |
+
on:click={reconnect}
|
171 |
+
disabled={disabled}>
|
172 |
+
Try reconnecting
|
173 |
+
</button>
|
174 |
+
</p>
|
175 |
+
{/if}
|
176 |
</article>
|
177 |
{#if pipelineParams}
|
178 |
<article class="my-3 grid grid-cols-1 gap-3 sm:grid-cols-4">
|
|
|
189 |
</div>
|
190 |
<div class="sm:col-span-4 sm:row-start-2">
|
191 |
<Button on:click={toggleLcmLive} {disabled} classList={'text-lg my-1 p-2'}>
|
192 |
+
{#if isConnecting}
|
193 |
+
Connecting...
|
194 |
+
{:else if isLCMRunning}
|
195 |
Stop
|
196 |
{:else}
|
197 |
Start
|
server/config.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
-
from
|
2 |
import argparse
|
3 |
import os
|
|
|
4 |
|
5 |
|
6 |
-
class Args(
|
7 |
host: str
|
8 |
port: int
|
9 |
reload: bool
|
@@ -13,19 +14,30 @@ class Args(NamedTuple):
|
|
13 |
torch_compile: bool
|
14 |
taesd: bool
|
15 |
pipeline: str
|
16 |
-
ssl_certfile: str
|
17 |
-
ssl_keyfile: str
|
18 |
sfast: bool
|
19 |
onediff: bool = False
|
20 |
compel: bool = False
|
21 |
debug: bool = False
|
22 |
|
23 |
-
def pretty_print(self):
|
24 |
print("\n")
|
25 |
-
for field, value in self.
|
26 |
print(f"{field}: {value}")
|
27 |
print("\n")
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
|
31 |
TIMEOUT = float(os.environ.get("TIMEOUT", 0))
|
@@ -113,5 +125,5 @@ parser.add_argument(
|
|
113 |
)
|
114 |
parser.set_defaults(taesd=USE_TAESD)
|
115 |
|
116 |
-
config = Args(
|
117 |
config.pretty_print()
|
|
|
1 |
+
from pydantic import BaseModel, field_validator
|
2 |
import argparse
|
3 |
import os
|
4 |
+
from typing import Annotated
|
5 |
|
6 |
|
7 |
+
class Args(BaseModel):
|
8 |
host: str
|
9 |
port: int
|
10 |
reload: bool
|
|
|
14 |
torch_compile: bool
|
15 |
taesd: bool
|
16 |
pipeline: str
|
17 |
+
ssl_certfile: str | None
|
18 |
+
ssl_keyfile: str | None
|
19 |
sfast: bool
|
20 |
onediff: bool = False
|
21 |
compel: bool = False
|
22 |
debug: bool = False
|
23 |
|
24 |
+
def pretty_print(self) -> None:
|
25 |
print("\n")
|
26 |
+
for field, value in self.model_dump().items():
|
27 |
print(f"{field}: {value}")
|
28 |
print("\n")
|
29 |
|
30 |
+
@field_validator("ssl_keyfile")
|
31 |
+
@classmethod
|
32 |
+
def validate_ssl_keyfile(cls, v: str | None, info) -> str | None:
|
33 |
+
"""Validate that if ssl_certfile is provided, ssl_keyfile is also provided."""
|
34 |
+
ssl_certfile = info.data.get("ssl_certfile")
|
35 |
+
if ssl_certfile and not v:
|
36 |
+
raise ValueError(
|
37 |
+
"If ssl_certfile is provided, ssl_keyfile must also be provided"
|
38 |
+
)
|
39 |
+
return v
|
40 |
+
|
41 |
|
42 |
MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
|
43 |
TIMEOUT = float(os.environ.get("TIMEOUT", 0))
|
|
|
125 |
)
|
126 |
parser.set_defaults(taesd=USE_TAESD)
|
127 |
|
128 |
+
config = Args.model_validate(vars(parser.parse_args()))
|
129 |
config.pretty_print()
|
server/connection_manager.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from uuid import UUID
|
2 |
import asyncio
|
3 |
from fastapi import WebSocket
|
|
|
4 |
from starlette.websockets import WebSocketState
|
5 |
import logging
|
6 |
from typing import Any
|
@@ -73,44 +74,157 @@ class ConnectionManager:
|
|
73 |
def get_user_count(self) -> int:
|
74 |
return len(self.active_connections)
|
75 |
|
76 |
-
def get_websocket(self, user_id: UUID) -> WebSocket:
|
77 |
user_session = self.active_connections.get(user_id)
|
78 |
if user_session:
|
79 |
websocket = user_session["websocket"]
|
80 |
-
|
|
|
|
|
|
|
81 |
return user_session["websocket"]
|
82 |
return None
|
83 |
|
84 |
async def disconnect(self, user_id: UUID):
|
85 |
-
|
86 |
-
if
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
self.delete_user(user_id)
|
89 |
|
90 |
async def send_json(self, user_id: UUID, data: dict):
|
91 |
try:
|
92 |
websocket = self.get_websocket(user_id)
|
93 |
if websocket:
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
except Exception as e:
|
96 |
logging.error(f"Error: Send json: {e}")
|
|
|
|
|
|
|
97 |
|
98 |
async def receive_json(self, user_id: UUID) -> dict | None:
|
99 |
try:
|
100 |
websocket = self.get_websocket(user_id)
|
101 |
if websocket:
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
return None
|
104 |
except Exception as e:
|
105 |
logging.error(f"Error: Receive json: {e}")
|
|
|
|
|
|
|
106 |
return None
|
107 |
|
108 |
async def receive_bytes(self, user_id: UUID) -> bytes | None:
|
109 |
try:
|
110 |
websocket = self.get_websocket(user_id)
|
111 |
if websocket:
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
return None
|
114 |
except Exception as e:
|
115 |
logging.error(f"Error: Receive bytes: {e}")
|
|
|
|
|
|
|
116 |
return None
|
|
|
1 |
from uuid import UUID
|
2 |
import asyncio
|
3 |
from fastapi import WebSocket
|
4 |
+
from fastapi.websockets import WebSocketDisconnect
|
5 |
from starlette.websockets import WebSocketState
|
6 |
import logging
|
7 |
from typing import Any
|
|
|
74 |
def get_user_count(self) -> int:
|
75 |
return len(self.active_connections)
|
76 |
|
77 |
+
def get_websocket(self, user_id: UUID) -> WebSocket | None:
|
78 |
user_session = self.active_connections.get(user_id)
|
79 |
if user_session:
|
80 |
websocket = user_session["websocket"]
|
81 |
+
# Both client_state and application_state should be checked
|
82 |
+
# to ensure the websocket is fully connected and not closing
|
83 |
+
if (websocket.client_state == WebSocketState.CONNECTED and
|
84 |
+
websocket.application_state == WebSocketState.CONNECTED):
|
85 |
return user_session["websocket"]
|
86 |
return None
|
87 |
|
88 |
async def disconnect(self, user_id: UUID):
|
89 |
+
# First check if user is in active connections
|
90 |
+
if user_id not in self.active_connections:
|
91 |
+
return
|
92 |
+
|
93 |
+
# Get the websocket directly from active_connections to avoid get_websocket validation
|
94 |
+
user_session = self.active_connections.get(user_id)
|
95 |
+
if user_session and "websocket" in user_session:
|
96 |
+
websocket = user_session["websocket"]
|
97 |
+
try:
|
98 |
+
# Only attempt close if not already closed
|
99 |
+
if (websocket.client_state != WebSocketState.DISCONNECTED and
|
100 |
+
websocket.application_state != WebSocketState.DISCONNECTED):
|
101 |
+
await websocket.close()
|
102 |
+
except Exception as e:
|
103 |
+
logging.error(f"Error closing websocket for {user_id}: {e}")
|
104 |
+
|
105 |
+
# Always delete the user to ensure cleanup
|
106 |
self.delete_user(user_id)
|
107 |
|
108 |
async def send_json(self, user_id: UUID, data: dict):
|
109 |
try:
|
110 |
websocket = self.get_websocket(user_id)
|
111 |
if websocket:
|
112 |
+
try:
|
113 |
+
await websocket.send_json(data)
|
114 |
+
except RuntimeError as e:
|
115 |
+
error_msg = str(e)
|
116 |
+
if any(err in error_msg for err in [
|
117 |
+
"WebSocket is not connected",
|
118 |
+
"Cannot call \"send\" once a close message has been sent",
|
119 |
+
"Cannot call \"receive\" once a close message has been sent",
|
120 |
+
"WebSocket is disconnected"]):
|
121 |
+
# The websocket was disconnected or is closing
|
122 |
+
logging.info(f"WebSocket disconnected/closing for user {user_id}: {error_msg}")
|
123 |
+
await self.disconnect(user_id)
|
124 |
+
else:
|
125 |
+
logging.error(f"Runtime error in send_json: {e}")
|
126 |
+
except WebSocketDisconnect as disconnect_error:
|
127 |
+
# Handle websocket disconnection event
|
128 |
+
code = disconnect_error.code
|
129 |
+
if code == 1006: # ABNORMAL_CLOSURE
|
130 |
+
logging.info(f"WebSocket abnormally closed for user {user_id} during send: Connection was closed without a proper close handshake")
|
131 |
+
else:
|
132 |
+
logging.info(f"WebSocket disconnected for user {user_id} with code {code} during send: {disconnect_error.reason}")
|
133 |
+
|
134 |
+
# Always disconnect the user
|
135 |
+
if user_id in self.active_connections:
|
136 |
+
await self.disconnect(user_id)
|
137 |
except Exception as e:
|
138 |
logging.error(f"Error: Send json: {e}")
|
139 |
+
# If any send fails, ensure the user gets removed to prevent further errors
|
140 |
+
if user_id in self.active_connections:
|
141 |
+
await self.disconnect(user_id)
|
142 |
|
143 |
async def receive_json(self, user_id: UUID) -> dict | None:
|
144 |
try:
|
145 |
websocket = self.get_websocket(user_id)
|
146 |
if websocket:
|
147 |
+
try:
|
148 |
+
# Receive the raw message and handle JSON parsing manually for better error handling
|
149 |
+
try:
|
150 |
+
data = await websocket.receive_json()
|
151 |
+
# Verify it's a dictionary
|
152 |
+
if not isinstance(data, dict):
|
153 |
+
logging.error(f"Expected dict but received {type(data)} from user {user_id}: {data}")
|
154 |
+
return None
|
155 |
+
return data
|
156 |
+
except ValueError as json_err:
|
157 |
+
# Specific handling for JSON parsing errors
|
158 |
+
logging.error(f"JSON parsing error for user {user_id}: {json_err}")
|
159 |
+
return None
|
160 |
+
except RuntimeError as e:
|
161 |
+
error_msg = str(e)
|
162 |
+
if any(err in error_msg for err in [
|
163 |
+
"WebSocket is not connected",
|
164 |
+
"Cannot call \"send\" once a close message has been sent",
|
165 |
+
"Cannot call \"receive\" once a close message has been sent",
|
166 |
+
"WebSocket is disconnected"]):
|
167 |
+
# The websocket was disconnected or closing
|
168 |
+
logging.info(f"WebSocket disconnected/closing for user {user_id}: {error_msg}")
|
169 |
+
await self.disconnect(user_id)
|
170 |
+
else:
|
171 |
+
logging.error(f"Runtime error in receive_json: {e}")
|
172 |
+
return None
|
173 |
+
return None
|
174 |
+
except WebSocketDisconnect as disconnect_error:
|
175 |
+
# Handle websocket disconnection event (this is a clean, expected path)
|
176 |
+
code = disconnect_error.code
|
177 |
+
if code == 1006: # ABNORMAL_CLOSURE
|
178 |
+
logging.info(f"WebSocket abnormally closed for user {user_id}: Connection was closed without a proper close handshake")
|
179 |
+
else:
|
180 |
+
logging.info(f"WebSocket disconnected for user {user_id} with code {code}: {disconnect_error.reason}")
|
181 |
+
|
182 |
+
# Always disconnect the user
|
183 |
+
if user_id in self.active_connections:
|
184 |
+
await self.disconnect(user_id)
|
185 |
return None
|
186 |
except Exception as e:
|
187 |
logging.error(f"Error: Receive json: {e}")
|
188 |
+
# Ensure disconnection on any exception
|
189 |
+
if user_id in self.active_connections:
|
190 |
+
await self.disconnect(user_id)
|
191 |
return None
|
192 |
|
193 |
async def receive_bytes(self, user_id: UUID) -> bytes | None:
|
194 |
try:
|
195 |
websocket = self.get_websocket(user_id)
|
196 |
if websocket:
|
197 |
+
try:
|
198 |
+
return await websocket.receive_bytes()
|
199 |
+
except RuntimeError as e:
|
200 |
+
error_msg = str(e)
|
201 |
+
if any(err in error_msg for err in [
|
202 |
+
"WebSocket is not connected",
|
203 |
+
"Cannot call \"send\" once a close message has been sent",
|
204 |
+
"Cannot call \"receive\" once a close message has been sent",
|
205 |
+
"WebSocket is disconnected"]):
|
206 |
+
# The websocket was disconnected or closing
|
207 |
+
logging.info(f"WebSocket disconnected/closing for user {user_id}: {error_msg}")
|
208 |
+
await self.disconnect(user_id)
|
209 |
+
else:
|
210 |
+
logging.error(f"Runtime error in receive_bytes: {e}")
|
211 |
+
return None
|
212 |
+
return None
|
213 |
+
except WebSocketDisconnect as disconnect_error:
|
214 |
+
# Handle websocket disconnection event (this is a clean, expected path)
|
215 |
+
code = disconnect_error.code
|
216 |
+
if code == 1006: # ABNORMAL_CLOSURE
|
217 |
+
logging.info(f"WebSocket abnormally closed for user {user_id}: Connection was closed without a proper close handshake")
|
218 |
+
else:
|
219 |
+
logging.info(f"WebSocket disconnected for user {user_id} with code {code}: {disconnect_error.reason}")
|
220 |
+
|
221 |
+
# Always disconnect the user
|
222 |
+
if user_id in self.active_connections:
|
223 |
+
await self.disconnect(user_id)
|
224 |
return None
|
225 |
except Exception as e:
|
226 |
logging.error(f"Error: Receive bytes: {e}")
|
227 |
+
# Ensure disconnection on any exception
|
228 |
+
if user_id in self.active_connections:
|
229 |
+
await self.disconnect(user_id)
|
230 |
return None
|
server/main.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
from fastapi import FastAPI, WebSocket, HTTPException
|
|
|
2 |
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
from fastapi.middleware.cors import CORSMiddleware
|
4 |
from fastapi.staticfiles import StaticFiles
|
@@ -9,37 +10,38 @@ from PIL import Image
|
|
9 |
import logging
|
10 |
from config import config, Args
|
11 |
from connection_manager import ConnectionManager, ServerFullException
|
12 |
-
import uuid
|
13 |
from uuid import UUID
|
14 |
import time
|
15 |
-
from typing import Any, Protocol, runtime_checkable
|
16 |
from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class, ParamsModel
|
17 |
from device import device, torch_dtype
|
18 |
import asyncio
|
19 |
import os
|
20 |
import time
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
@runtime_checkable
|
25 |
class BasePipeline(Protocol):
|
26 |
class Info:
|
27 |
@classmethod
|
28 |
-
def schema(cls) -> dict[str, Any]:
|
29 |
-
|
30 |
page_content: str | None
|
31 |
input_mode: str
|
32 |
-
|
33 |
class InputParams(ParamsModel):
|
34 |
@classmethod
|
35 |
-
def schema(cls) -> dict[str, Any]:
|
36 |
-
|
37 |
-
|
38 |
-
def dict(self) -> dict[str, Any]:
|
39 |
-
...
|
40 |
-
|
41 |
-
def predict(self, params: ParamsModel) -> Image.Image | None:
|
42 |
-
...
|
43 |
|
44 |
|
45 |
THROTTLE = 1.0 / 120
|
@@ -74,7 +76,22 @@ class App:
|
|
74 |
await handle_websocket_data(user_id)
|
75 |
except ServerFullException as e:
|
76 |
logging.error(f"Server Full: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
finally:
|
|
|
78 |
await self.conn_manager.disconnect(user_id)
|
79 |
logging.info(f"User disconnected: {user_id}")
|
80 |
|
@@ -99,34 +116,73 @@ class App:
|
|
99 |
return
|
100 |
data = await self.conn_manager.receive_json(user_id)
|
101 |
if data is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
continue
|
103 |
-
|
104 |
if data["status"] == "next_frame":
|
105 |
info = self.pipeline.Info()
|
106 |
params_data = await self.conn_manager.receive_json(user_id)
|
107 |
if params_data is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
continue
|
109 |
-
|
110 |
params = self.pipeline.InputParams.model_validate(params_data)
|
111 |
-
|
112 |
if info.input_mode == "image":
|
113 |
image_data = await self.conn_manager.receive_bytes(user_id)
|
114 |
-
if image_data is None
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
await self.conn_manager.send_json(
|
116 |
user_id, {"status": "send_frame"}
|
117 |
)
|
118 |
continue
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
124 |
|
125 |
await self.conn_manager.update_data(user_id, params)
|
126 |
await self.conn_manager.send_json(user_id, {"status": "wait"})
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
except Exception as e:
|
129 |
-
logging.error(f"Websocket Error: {e}, {user_id}
|
130 |
await self.conn_manager.disconnect(user_id)
|
131 |
|
132 |
@self.app.get("/api/queue")
|
@@ -137,25 +193,58 @@ class App:
|
|
137 |
@self.app.get("/api/stream/{user_id}")
|
138 |
async def stream(user_id: UUID, request: Request) -> StreamingResponse:
|
139 |
try:
|
140 |
-
|
|
|
141 |
last_params: ParamsModel | None = None
|
142 |
while True:
|
|
|
|
|
|
|
|
|
|
|
143 |
last_time = time.time()
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
params = await self.conn_manager.get_latest_data(user_id)
|
148 |
-
|
149 |
-
if
|
150 |
-
(last_params is not None and
|
151 |
-
params.model_dump() == last_params.model_dump())):
|
152 |
await asyncio.sleep(THROTTLE)
|
153 |
continue
|
154 |
-
|
155 |
-
last_params = params
|
156 |
-
image = self.pipeline.predict(params)
|
157 |
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
image, has_nsfw_concept = self.safety_checker(image)
|
160 |
if has_nsfw_concept:
|
161 |
image = None
|
@@ -175,9 +264,34 @@ class App:
|
|
175 |
media_type="multipart/x-mixed-replace;boundary=frame",
|
176 |
headers={"Cache-Control": "no-cache"},
|
177 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
except Exception as e:
|
179 |
-
logging.error(f"Streaming Error: {e}, {user_id}
|
180 |
-
|
|
|
|
|
181 |
|
182 |
# route to setup frontend
|
183 |
@self.app.get("/api/settings")
|
@@ -185,7 +299,7 @@ class App:
|
|
185 |
info_schema = self.pipeline.Info.schema()
|
186 |
info = self.pipeline.Info()
|
187 |
page_content = ""
|
188 |
-
if hasattr(info,
|
189 |
page_content = markdown2.markdown(info.page_content)
|
190 |
|
191 |
input_params = self.pipeline.InputParams.schema()
|
@@ -234,7 +348,7 @@ if __name__ == "__main__":
|
|
234 |
# app = create_app(config) # Create the app once
|
235 |
|
236 |
uvicorn.run(
|
237 |
-
app,
|
238 |
host=config.host,
|
239 |
port=config.port,
|
240 |
reload=config.reload,
|
|
|
1 |
+
from fastapi import FastAPI, WebSocket, HTTPException
|
2 |
+
from fastapi.websockets import WebSocketDisconnect
|
3 |
from fastapi.responses import StreamingResponse, JSONResponse
|
4 |
from fastapi.middleware.cors import CORSMiddleware
|
5 |
from fastapi.staticfiles import StaticFiles
|
|
|
10 |
import logging
|
11 |
from config import config, Args
|
12 |
from connection_manager import ConnectionManager, ServerFullException
|
|
|
13 |
from uuid import UUID
|
14 |
import time
|
15 |
+
from typing import Any, Protocol, runtime_checkable, AsyncGenerator
|
16 |
from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class, ParamsModel
|
17 |
from device import device, torch_dtype
|
18 |
import asyncio
|
19 |
import os
|
20 |
import time
|
21 |
+
|
22 |
+
# Common WebSocket error messages that indicate disconnection
|
23 |
+
ERROR_MESSAGES = [
|
24 |
+
"WebSocket is not connected",
|
25 |
+
'Cannot call "send" once a close message has been sent',
|
26 |
+
'Cannot call "receive" once a close message has been sent',
|
27 |
+
"WebSocket is disconnected",
|
28 |
+
]
|
29 |
+
|
30 |
|
31 |
@runtime_checkable
|
32 |
class BasePipeline(Protocol):
|
33 |
class Info:
|
34 |
@classmethod
|
35 |
+
def schema(cls) -> dict[str, Any]: ...
|
36 |
+
|
37 |
page_content: str | None
|
38 |
input_mode: str
|
39 |
+
|
40 |
class InputParams(ParamsModel):
|
41 |
@classmethod
|
42 |
+
def schema(cls) -> dict[str, Any]: ...
|
43 |
+
|
44 |
+
def predict(self, params: ParamsModel) -> Image.Image | None: ...
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
|
47 |
THROTTLE = 1.0 / 120
|
|
|
76 |
await handle_websocket_data(user_id)
|
77 |
except ServerFullException as e:
|
78 |
logging.error(f"Server Full: {e}")
|
79 |
+
except WebSocketDisconnect as disconnect_error:
|
80 |
+
# Handle websocket disconnection event
|
81 |
+
code = disconnect_error.code
|
82 |
+
if code == 1006: # ABNORMAL_CLOSURE
|
83 |
+
logging.info(f"WebSocket abnormally closed for user {user_id}: Connection was closed without a proper close handshake")
|
84 |
+
else:
|
85 |
+
logging.info(f"WebSocket disconnected for user {user_id} with code {code}: {disconnect_error.reason}")
|
86 |
+
except RuntimeError as e:
|
87 |
+
if any(err in str(e) for err in ERROR_MESSAGES):
|
88 |
+
logging.info(f"WebSocket disconnected for user {user_id}: {e}")
|
89 |
+
else:
|
90 |
+
logging.error(f"Runtime error in websocket endpoint: {e}")
|
91 |
+
except Exception as e:
|
92 |
+
logging.error(f"Unexpected error in websocket endpoint: {e}")
|
93 |
finally:
|
94 |
+
# Always ensure we disconnect the user
|
95 |
await self.conn_manager.disconnect(user_id)
|
96 |
logging.info(f"User disconnected: {user_id}")
|
97 |
|
|
|
116 |
return
|
117 |
data = await self.conn_manager.receive_json(user_id)
|
118 |
if data is None:
|
119 |
+
# Check if the user is still connected - they might have disconnected
|
120 |
+
if not self.conn_manager.check_user(user_id):
|
121 |
+
logging.info(
|
122 |
+
f"User {user_id} disconnected, exiting handle_websocket_data loop"
|
123 |
+
)
|
124 |
+
return
|
125 |
+
continue
|
126 |
+
|
127 |
+
# Validate that data is a dictionary and has a status field
|
128 |
+
if not isinstance(data, dict) or "status" not in data:
|
129 |
+
logging.error(
|
130 |
+
f"Invalid data format received from user {user_id}: {data}"
|
131 |
+
)
|
132 |
continue
|
133 |
+
|
134 |
if data["status"] == "next_frame":
|
135 |
info = self.pipeline.Info()
|
136 |
params_data = await self.conn_manager.receive_json(user_id)
|
137 |
if params_data is None:
|
138 |
+
# Check if the user is still connected
|
139 |
+
if not self.conn_manager.check_user(user_id):
|
140 |
+
logging.info(
|
141 |
+
f"User {user_id} disconnected during params reception"
|
142 |
+
)
|
143 |
+
return
|
144 |
continue
|
145 |
+
|
146 |
params = self.pipeline.InputParams.model_validate(params_data)
|
147 |
+
|
148 |
if info.input_mode == "image":
|
149 |
image_data = await self.conn_manager.receive_bytes(user_id)
|
150 |
+
if image_data is None:
|
151 |
+
# Check if the user is still connected
|
152 |
+
if not self.conn_manager.check_user(user_id):
|
153 |
+
logging.info(
|
154 |
+
f"User {user_id} disconnected during image reception"
|
155 |
+
)
|
156 |
+
return
|
157 |
await self.conn_manager.send_json(
|
158 |
user_id, {"status": "send_frame"}
|
159 |
)
|
160 |
continue
|
161 |
+
if len(image_data) == 0:
|
162 |
+
await self.conn_manager.send_json(
|
163 |
+
user_id, {"status": "send_frame"}
|
164 |
+
)
|
165 |
+
continue
|
166 |
+
|
167 |
+
# Add the image directly to the model using setattr
|
168 |
+
# This works because we've configured the ParamsModel to allow extra fields
|
169 |
+
setattr(params, "image", bytes_to_pil(image_data))
|
170 |
|
171 |
await self.conn_manager.update_data(user_id, params)
|
172 |
await self.conn_manager.send_json(user_id, {"status": "wait"})
|
173 |
|
174 |
+
except RuntimeError as e:
|
175 |
+
error_msg = str(e)
|
176 |
+
if any(err in error_msg for err in ERROR_MESSAGES):
|
177 |
+
logging.info(
|
178 |
+
f"WebSocket disconnected for user {user_id}: {error_msg}"
|
179 |
+
)
|
180 |
+
else:
|
181 |
+
logging.error(f"Websocket Runtime Error: {e}, {user_id}")
|
182 |
+
# Ensure disconnect is called
|
183 |
+
await self.conn_manager.disconnect(user_id)
|
184 |
except Exception as e:
|
185 |
+
logging.error(f"Websocket Error: {e}, {user_id}")
|
186 |
await self.conn_manager.disconnect(user_id)
|
187 |
|
188 |
@self.app.get("/api/queue")
|
|
|
193 |
@self.app.get("/api/stream/{user_id}")
|
194 |
async def stream(user_id: UUID, request: Request) -> StreamingResponse:
|
195 |
try:
|
196 |
+
|
197 |
+
async def generate() -> AsyncGenerator[bytes, None]:
|
198 |
last_params: ParamsModel | None = None
|
199 |
while True:
|
200 |
+
# Check if the user is still connected
|
201 |
+
if not self.conn_manager.check_user(user_id):
|
202 |
+
logging.info(f"User {user_id} disconnected from stream")
|
203 |
+
break
|
204 |
+
|
205 |
last_time = time.time()
|
206 |
+
try:
|
207 |
+
await self.conn_manager.send_json(
|
208 |
+
user_id, {"status": "send_frame"}
|
209 |
+
)
|
210 |
+
except Exception as e:
|
211 |
+
logging.error(f"Error sending to websocket in stream: {e}")
|
212 |
+
# User might have disconnected
|
213 |
+
if not self.conn_manager.check_user(user_id):
|
214 |
+
logging.info(f"User {user_id} disconnected from stream")
|
215 |
+
break
|
216 |
+
await asyncio.sleep(THROTTLE)
|
217 |
+
continue
|
218 |
+
|
219 |
params = await self.conn_manager.get_latest_data(user_id)
|
220 |
+
|
221 |
+
if params is None:
|
|
|
|
|
222 |
await asyncio.sleep(THROTTLE)
|
223 |
continue
|
|
|
|
|
|
|
224 |
|
225 |
+
try:
|
226 |
+
# Check if the params haven't changed since last time
|
227 |
+
if (
|
228 |
+
last_params is not None
|
229 |
+
and params.model_dump() == last_params.model_dump()
|
230 |
+
):
|
231 |
+
await asyncio.sleep(THROTTLE)
|
232 |
+
continue
|
233 |
+
|
234 |
+
last_params = params
|
235 |
+
image = self.pipeline.predict(params)
|
236 |
+
except Exception as e:
|
237 |
+
logging.error(
|
238 |
+
f"Error processing params for user {user_id}: {e}"
|
239 |
+
)
|
240 |
+
await asyncio.sleep(THROTTLE)
|
241 |
+
continue
|
242 |
+
|
243 |
+
if (
|
244 |
+
self.args.safety_checker
|
245 |
+
and self.safety_checker is not None
|
246 |
+
and image is not None
|
247 |
+
):
|
248 |
image, has_nsfw_concept = self.safety_checker(image)
|
249 |
if has_nsfw_concept:
|
250 |
image = None
|
|
|
264 |
media_type="multipart/x-mixed-replace;boundary=frame",
|
265 |
headers={"Cache-Control": "no-cache"},
|
266 |
)
|
267 |
+
except WebSocketDisconnect as disconnect_error:
|
268 |
+
# Handle websocket disconnection event
|
269 |
+
code = disconnect_error.code
|
270 |
+
if code == 1006: # ABNORMAL_CLOSURE
|
271 |
+
logging.info(f"WebSocket abnormally closed during streaming for user {user_id}: Connection was closed without a proper close handshake")
|
272 |
+
else:
|
273 |
+
logging.info(f"WebSocket disconnected during streaming for user {user_id} with code {code}: {disconnect_error.reason}")
|
274 |
+
|
275 |
+
# Clean disconnection without error response
|
276 |
+
await self.conn_manager.disconnect(user_id)
|
277 |
+
raise HTTPException(status_code=204, detail="Connection closed")
|
278 |
+
except RuntimeError as e:
|
279 |
+
error_msg = str(e)
|
280 |
+
if any(err in error_msg for err in ERROR_MESSAGES):
|
281 |
+
logging.info(
|
282 |
+
f"WebSocket disconnected during streaming for user {user_id}: {error_msg}"
|
283 |
+
)
|
284 |
+
# Clean disconnection without error response
|
285 |
+
await self.conn_manager.disconnect(user_id)
|
286 |
+
raise HTTPException(status_code=204, detail="Connection closed")
|
287 |
+
else:
|
288 |
+
logging.error(f"Streaming Runtime Error: {e}, {user_id}")
|
289 |
+
raise HTTPException(status_code=500, detail="Streaming error")
|
290 |
except Exception as e:
|
291 |
+
logging.error(f"Streaming Error: {e}, {user_id}")
|
292 |
+
# Always ensure we disconnect the user on error
|
293 |
+
await self.conn_manager.disconnect(user_id)
|
294 |
+
raise HTTPException(status_code=500, detail="Streaming error")
|
295 |
|
296 |
# route to setup frontend
|
297 |
@self.app.get("/api/settings")
|
|
|
299 |
info_schema = self.pipeline.Info.schema()
|
300 |
info = self.pipeline.Info()
|
301 |
page_content = ""
|
302 |
+
if hasattr(info, "page_content") and info.page_content:
|
303 |
page_content = markdown2.markdown(info.page_content)
|
304 |
|
305 |
input_params = self.pipeline.InputParams.schema()
|
|
|
348 |
# app = create_app(config) # Create the app once
|
349 |
|
350 |
uvicorn.run(
|
351 |
+
"main:app",
|
352 |
host=config.host,
|
353 |
port=config.port,
|
354 |
reload=config.reload,
|
server/pipelines/img2img.py
CHANGED
@@ -14,6 +14,7 @@ import psutil
|
|
14 |
from config import Args
|
15 |
from pydantic import BaseModel, Field
|
16 |
from PIL import Image
|
|
|
17 |
import math
|
18 |
|
19 |
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
@@ -54,7 +55,7 @@ class Pipeline:
|
|
54 |
input_mode: str = "image"
|
55 |
page_content: str = page_content
|
56 |
|
57 |
-
class InputParams(
|
58 |
prompt: str = Field(
|
59 |
default_prompt,
|
60 |
title="Prompt",
|
|
|
14 |
from config import Args
|
15 |
from pydantic import BaseModel, Field
|
16 |
from PIL import Image
|
17 |
+
from util import ParamsModel
|
18 |
import math
|
19 |
|
20 |
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
|
|
55 |
input_mode: str = "image"
|
56 |
page_content: str = page_content
|
57 |
|
58 |
+
class InputParams(ParamsModel):
|
59 |
prompt: str = Field(
|
60 |
default_prompt,
|
61 |
title="Prompt",
|
server/pipelines/img2imgFlux.py
CHANGED
@@ -28,6 +28,7 @@ from config import Args
|
|
28 |
from pydantic import BaseModel, Field
|
29 |
from PIL import Image
|
30 |
from pathlib import Path
|
|
|
31 |
import math
|
32 |
import gc
|
33 |
|
@@ -61,7 +62,7 @@ class Pipeline:
|
|
61 |
input_mode: str = "image"
|
62 |
page_content: str = page_content
|
63 |
|
64 |
-
class InputParams(
|
65 |
prompt: str = Field(
|
66 |
default_prompt,
|
67 |
title="Prompt",
|
|
|
28 |
from pydantic import BaseModel, Field
|
29 |
from PIL import Image
|
30 |
from pathlib import Path
|
31 |
+
from util import ParamsModel
|
32 |
import math
|
33 |
import gc
|
34 |
|
|
|
62 |
input_mode: str = "image"
|
63 |
page_content: str = page_content
|
64 |
|
65 |
+
class InputParams(ParamsModel):
|
66 |
prompt: str = Field(
|
67 |
default_prompt,
|
68 |
title="Prompt",
|
server/pipelines/img2imgSDTurbo.py
CHANGED
@@ -9,10 +9,10 @@ try:
|
|
9 |
except:
|
10 |
pass
|
11 |
|
12 |
-
import psutil
|
13 |
from config import Args
|
14 |
from pydantic import BaseModel, Field
|
15 |
from PIL import Image
|
|
|
16 |
import math
|
17 |
|
18 |
|
@@ -55,7 +55,7 @@ class Pipeline:
|
|
55 |
input_mode: str = "image"
|
56 |
page_content: str = page_content
|
57 |
|
58 |
-
class InputParams(
|
59 |
prompt: str = Field(
|
60 |
default_prompt,
|
61 |
title="Prompt",
|
|
|
9 |
except:
|
10 |
pass
|
11 |
|
|
|
12 |
from config import Args
|
13 |
from pydantic import BaseModel, Field
|
14 |
from PIL import Image
|
15 |
+
from util import ParamsModel
|
16 |
import math
|
17 |
|
18 |
|
|
|
55 |
input_mode: str = "image"
|
56 |
page_content: str = page_content
|
57 |
|
58 |
+
class InputParams(ParamsModel):
|
59 |
prompt: str = Field(
|
60 |
default_prompt,
|
61 |
title="Prompt",
|
server/pipelines/img2imgSDXL-Lightning.py
CHANGED
@@ -18,6 +18,7 @@ from huggingface_hub import hf_hub_download
|
|
18 |
from config import Args
|
19 |
from pydantic import BaseModel, Field
|
20 |
from PIL import Image
|
|
|
21 |
import math
|
22 |
|
23 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
@@ -62,7 +63,7 @@ class Pipeline:
|
|
62 |
input_mode: str = "image"
|
63 |
page_content: str = page_content
|
64 |
|
65 |
-
class InputParams(
|
66 |
prompt: str = Field(
|
67 |
default_prompt,
|
68 |
title="Prompt",
|
|
|
18 |
from config import Args
|
19 |
from pydantic import BaseModel, Field
|
20 |
from PIL import Image
|
21 |
+
from util import ParamsModel
|
22 |
import math
|
23 |
|
24 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
|
|
63 |
input_mode: str = "image"
|
64 |
page_content: str = page_content
|
65 |
|
66 |
+
class InputParams(ParamsModel):
|
67 |
prompt: str = Field(
|
68 |
default_prompt,
|
69 |
title="Prompt",
|
server/pipelines/img2imgSDXLTurbo.py
CHANGED
@@ -14,6 +14,7 @@ import psutil
|
|
14 |
from config import Args
|
15 |
from pydantic import BaseModel, Field
|
16 |
from PIL import Image
|
|
|
17 |
import math
|
18 |
|
19 |
base_model = "stabilityai/sdxl-turbo"
|
@@ -55,7 +56,7 @@ class Pipeline:
|
|
55 |
input_mode: str = "image"
|
56 |
page_content: str = page_content
|
57 |
|
58 |
-
class InputParams(
|
59 |
prompt: str = Field(
|
60 |
default_prompt,
|
61 |
title="Prompt",
|
|
|
14 |
from config import Args
|
15 |
from pydantic import BaseModel, Field
|
16 |
from PIL import Image
|
17 |
+
from util import ParamsModel
|
18 |
import math
|
19 |
|
20 |
base_model = "stabilityai/sdxl-turbo"
|
|
|
56 |
input_mode: str = "image"
|
57 |
page_content: str = page_content
|
58 |
|
59 |
+
class InputParams(ParamsModel):
|
60 |
prompt: str = Field(
|
61 |
default_prompt,
|
62 |
title="Prompt",
|
server/pipelines/img2imgSDXS512.py
CHANGED
@@ -11,6 +11,7 @@ import psutil
|
|
11 |
from config import Args
|
12 |
from pydantic import BaseModel, Field
|
13 |
from PIL import Image
|
|
|
14 |
import math
|
15 |
|
16 |
base_model = "IDKiro/sdxs-512-0.9"
|
@@ -51,7 +52,7 @@ class Pipeline:
|
|
51 |
input_mode: str = "image"
|
52 |
page_content: str = page_content
|
53 |
|
54 |
-
class InputParams(
|
55 |
prompt: str = Field(
|
56 |
default_prompt,
|
57 |
title="Prompt",
|
|
|
11 |
from config import Args
|
12 |
from pydantic import BaseModel, Field
|
13 |
from PIL import Image
|
14 |
+
from util import ParamsModel
|
15 |
import math
|
16 |
|
17 |
base_model = "IDKiro/sdxs-512-0.9"
|
|
|
52 |
input_mode: str = "image"
|
53 |
page_content: str = page_content
|
54 |
|
55 |
+
class InputParams(ParamsModel):
|
56 |
prompt: str = Field(
|
57 |
default_prompt,
|
58 |
title="Prompt",
|
server/util.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from importlib import import_module
|
2 |
-
from typing import Any, TypeVar
|
3 |
from PIL import Image
|
4 |
import io
|
5 |
from pydantic import BaseModel
|
@@ -11,12 +11,17 @@ TPipeline = TypeVar("TPipeline", bound=type[Any])
|
|
11 |
|
12 |
class ParamsModel(BaseModel):
|
13 |
"""Base model for pipeline parameters."""
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
15 |
@classmethod
|
16 |
-
def from_dict(cls, data: dict[str, Any]) ->
|
17 |
"""Create a model instance from dictionary data."""
|
18 |
return cls.model_validate(data)
|
19 |
-
|
20 |
def to_dict(self) -> dict[str, Any]:
|
21 |
"""Convert model to dictionary."""
|
22 |
return self.model_dump()
|
@@ -25,13 +30,13 @@ class ParamsModel(BaseModel):
|
|
25 |
def get_pipeline_class(pipeline_name: str) -> type:
|
26 |
"""
|
27 |
Dynamically imports and returns the Pipeline class from a specified module.
|
28 |
-
|
29 |
Args:
|
30 |
pipeline_name: The name of the pipeline module to import
|
31 |
-
|
32 |
Returns:
|
33 |
The Pipeline class from the specified module
|
34 |
-
|
35 |
Raises:
|
36 |
ValueError: If the module or Pipeline class isn't found
|
37 |
TypeError: If Pipeline is not a class
|
|
|
1 |
from importlib import import_module
|
2 |
+
from typing import Any, TypeVar
|
3 |
from PIL import Image
|
4 |
import io
|
5 |
from pydantic import BaseModel
|
|
|
11 |
|
12 |
class ParamsModel(BaseModel):
|
13 |
"""Base model for pipeline parameters."""
|
14 |
+
|
15 |
+
model_config = {
|
16 |
+
"arbitrary_types_allowed": True,
|
17 |
+
"extra": "allow", # Allow extra attributes for dynamic fields like 'image'
|
18 |
+
}
|
19 |
+
|
20 |
@classmethod
|
21 |
+
def from_dict(cls, data: dict[str, Any]) -> "ParamsModel":
|
22 |
"""Create a model instance from dictionary data."""
|
23 |
return cls.model_validate(data)
|
24 |
+
|
25 |
def to_dict(self) -> dict[str, Any]:
|
26 |
"""Convert model to dictionary."""
|
27 |
return self.model_dump()
|
|
|
30 |
def get_pipeline_class(pipeline_name: str) -> type:
|
31 |
"""
|
32 |
Dynamically imports and returns the Pipeline class from a specified module.
|
33 |
+
|
34 |
Args:
|
35 |
pipeline_name: The name of the pipeline module to import
|
36 |
+
|
37 |
Returns:
|
38 |
The Pipeline class from the specified module
|
39 |
+
|
40 |
Raises:
|
41 |
ValueError: If the module or Pipeline class isn't found
|
42 |
TypeError: If Pipeline is not a class
|