praeclarumjj3 committed on
Commit 6f12dd3 · 1 Parent(s): dc7f71c

Update chat.py

Files changed (1):
  1. chat.py +5 -6
chat.py CHANGED
@@ -11,10 +11,9 @@ from vcoder_llava.mm_utils import process_images, load_image_from_base64, tokeni
 from vcoder_llava.constants import (
     IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
     SEG_TOKEN_INDEX, DEFAULT_SEG_TOKEN,
-    DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN,
+    DEPTH_TOKEN_INDEX, DEFAULT_DEPTH_TOKEN
 )
 from transformers import TextIteratorStreamer
-from threading import Thread
 
 class Chat:
     def __init__(self, model_path, model_base, model_name,
@@ -36,7 +35,7 @@ class Chat:
             model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
         self.is_multimodal = 'llava' in self.model_name.lower()
         self.is_seg = "vcoder" in self.model_name.lower()
-        self.is_depth = "ds" in self.model_name.lower()
+        self.is_depth = False
 
     @torch.inference_mode()
     def generate_stream(self, params):
@@ -168,21 +167,21 @@ class Chat:
                 "text": server_error_msg,
                 "error_code": 1,
             }
-            yield json.dumps(ret).encode() + b"\0"
+            yield json.dumps(ret).encode()
         except torch.cuda.CudaError as e:
             print("Caught torch.cuda.CudaError:", e)
             ret = {
                 "text": server_error_msg,
                 "error_code": 1,
             }
-            yield json.dumps(ret).encode() + b"\0"
+            yield json.dumps(ret).encode()
         except Exception as e:
             print("Caught Unknown Error", e)
             ret = {
                 "text": server_error_msg,
                 "error_code": 1,
             }
-            yield json.dumps(ret).encode() + b"\0"
 
 
 if __name__ == "__main__":
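
Note on the error paths above: each of the three exception handlers now yields json.dumps(ret).encode() without the trailing b"\0" byte that previously terminated each streamed record. Below is a minimal consumer sketch, not part of this repository: consume_stream and its arguments are hypothetical, and the success-path payload is assumed to look like the error payloads (a dict with "text" and "error_code"). It strips any trailing b"\0" so it tolerates chunks with or without the old separator.

import json

def consume_stream(chat, params):
    # Iterate over the byte chunks yielded by Chat.generate_stream().
    for chunk in chat.generate_stream(params):
        if not chunk:
            continue
        # Strip a trailing b"\0" in case any other yield site still appends
        # the old record separator, then parse the JSON payload.
        data = json.loads(chunk.rstrip(b"\0").decode())
        if data.get("error_code", 0) != 0:
            # Error payloads carry server_error_msg and error_code == 1.
            print("server error:", data["text"])
            break
        print(data["text"], end="", flush=True)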