praeclarumjj3 commited on
Commit
03dacef
·
1 Parent(s): 016e4dd

Update chat.py

Browse files
Files changed (1) hide show
  1. chat.py +6 -5
chat.py CHANGED
@@ -11,9 +11,10 @@ from vcoder_llava.mm_utils import process_images, load_image_from_base64, tokeni
11
  from vcoder_llava.constants import (
12
  IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
13
  SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
14
- DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN
15
  )
16
  from transformers import TextIteratorStreamer
 
17
 
18
  class Chat:
19
  def __init__(self, model_path, model_base, model_name,
@@ -35,7 +36,7 @@ class Chat:
35
  model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
36
  self.is_multimodal = 'llava' in self.model_name.lower()
37
  self.is_seg = "vcoder" in self.model_name.lower()
38
- self.is_depth = False
39
 
40
  @torch.inference_mode()
41
  def generate_stream(self, params):
@@ -167,21 +168,21 @@ class Chat:
167
  "text": server_error_msg,
168
  "error_code": 1,
169
  }
170
- yield json.dumps(ret).encode()
171
  except torch.cuda.CudaError as e:
172
  print("Caught torch.cuda.CudaError:", e)
173
  ret = {
174
  "text": server_error_msg,
175
  "error_code": 1,
176
  }
177
- yield json.dumps(ret).encode()
178
  except Exception as e:
179
  print("Caught Unknown Error", e)
180
  ret = {
181
  "text": server_error_msg,
182
  "error_code": 1,
183
  }
184
- yield json.dumps(ret).encode()
185
 
186
 
187
  if __name__ == "__main__":
 
11
  from vcoder_llava.constants import (
12
  IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
13
  SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
14
+ DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN,
15
  )
16
  from transformers import TextIteratorStreamer
17
+ from threading import Thread
18
 
19
  class Chat:
20
  def __init__(self, model_path, model_base, model_name,
 
36
  model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
37
  self.is_multimodal = 'llava' in self.model_name.lower()
38
  self.is_seg = "vcoder" in self.model_name.lower()
39
+ self.is_depth = "ds" in self.model_name.lower()
40
 
41
  @torch.inference_mode()
42
  def generate_stream(self, params):
 
168
  "text": server_error_msg,
169
  "error_code": 1,
170
  }
171
+ yield json.dumps(ret).encode() + b"\0"
172
  except torch.cuda.CudaError as e:
173
  print("Caught torch.cuda.CudaError:", e)
174
  ret = {
175
  "text": server_error_msg,
176
  "error_code": 1,
177
  }
178
+ yield json.dumps(ret).encode() + b"\0"
179
  except Exception as e:
180
  print("Caught Unknown Error", e)
181
  ret = {
182
  "text": server_error_msg,
183
  "error_code": 1,
184
  }
185
+ yield json.dumps(ret).encode() + b"\0"
186
 
187
 
188
  if __name__ == "__main__":