Spaces:
Runtime error
Runtime error
praeclarumjj3
commited on
Commit
·
bb7e9ea
1
Parent(s):
a7e7927
Update chat.py
Browse files
chat.py
CHANGED
@@ -11,9 +11,10 @@ from vcoder_llava.mm_utils import process_images, load_image_from_base64, tokeni
|
|
11 |
from vcoder_llava.constants import (
|
12 |
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
|
13 |
SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
|
14 |
-
DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN
|
15 |
)
|
16 |
from transformers import TextIteratorStreamer
|
|
|
17 |
|
18 |
class Chat:
|
19 |
def __init__(self, model_path, model_base, model_name,
|
@@ -35,7 +36,7 @@ class Chat:
|
|
35 |
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
36 |
self.is_multimodal = 'llava' in self.model_name.lower()
|
37 |
self.is_seg = "vcoder" in self.model_name.lower()
|
38 |
-
self.is_depth =
|
39 |
|
40 |
@torch.inference_mode()
|
41 |
def generate_stream(self, params):
|
@@ -167,21 +168,21 @@ class Chat:
|
|
167 |
"text": server_error_msg,
|
168 |
"error_code": 1,
|
169 |
}
|
170 |
-
yield json.dumps(ret).encode()
|
171 |
except torch.cuda.CudaError as e:
|
172 |
print("Caught torch.cuda.CudaError:", e)
|
173 |
ret = {
|
174 |
"text": server_error_msg,
|
175 |
"error_code": 1,
|
176 |
}
|
177 |
-
yield json.dumps(ret).encode()
|
178 |
except Exception as e:
|
179 |
print("Caught Unknown Error", e)
|
180 |
ret = {
|
181 |
"text": server_error_msg,
|
182 |
"error_code": 1,
|
183 |
}
|
184 |
-
yield json.dumps(ret).encode()
|
185 |
|
186 |
|
187 |
if __name__ == "__main__":
|
|
|
11 |
from vcoder_llava.constants import (
|
12 |
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
|
13 |
SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
|
14 |
+
DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN,
|
15 |
)
|
16 |
from transformers import TextIteratorStreamer
|
17 |
+
from threading import Thread
|
18 |
|
19 |
class Chat:
|
20 |
def __init__(self, model_path, model_base, model_name,
|
|
|
36 |
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
37 |
self.is_multimodal = 'llava' in self.model_name.lower()
|
38 |
self.is_seg = "vcoder" in self.model_name.lower()
|
39 |
+
self.is_depth = "ds" in self.model_name.lower()
|
40 |
|
41 |
@torch.inference_mode()
|
42 |
def generate_stream(self, params):
|
|
|
168 |
"text": server_error_msg,
|
169 |
"error_code": 1,
|
170 |
}
|
171 |
+
yield json.dumps(ret).encode() + b"\0"
|
172 |
except torch.cuda.CudaError as e:
|
173 |
print("Caught torch.cuda.CudaError:", e)
|
174 |
ret = {
|
175 |
"text": server_error_msg,
|
176 |
"error_code": 1,
|
177 |
}
|
178 |
+
yield json.dumps(ret).encode() + b"\0"
|
179 |
except Exception as e:
|
180 |
print("Caught Unknown Error", e)
|
181 |
ret = {
|
182 |
"text": server_error_msg,
|
183 |
"error_code": 1,
|
184 |
}
|
185 |
+
yield json.dumps(ret).encode() + b"\0"
|
186 |
|
187 |
|
188 |
if __name__ == "__main__":
|