yejunliang23 committed
Commit 07402ca · verified · 1 parent: 3abec55

Update app.py

Files changed (1):
  1. app.py  +3 -7
app.py CHANGED
@@ -21,10 +21,6 @@ from huggingface_hub import hf_hub_download
 import numpy as np
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
-import torch
-print("CUDA available:", torch.cuda.is_available())
-print("CUDA device count:", torch.cuda.device_count())
-1/0
 def _remove_image_special(text):
     text = text.replace('<ref>', '').replace('</ref>', '')
     return re.sub(r'<box>.*?(</box>|$)', '', text)
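The removed block was a temporary startup probe: the two prints report whether the container can see a GPU, and the bare 1/0 deliberately raises ZeroDivisionError so the Space stops with a traceback right after the report. If a similar check is ever wanted again, a non-fatal variant could look like the sketch below (illustrative only, not part of this commit):

    import logging
    import torch

    # Illustrative only: report GPU visibility without crashing the app.
    logging.basicConfig(level=logging.INFO)
    logging.info("CUDA available: %s", torch.cuda.is_available())
    logging.info("CUDA device count: %d", torch.cuda.device_count())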
@@ -106,7 +102,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
 text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 image_inputs, video_inputs = process_vision_info(messages)
 inputs = processor(text=[text], images=image_inputs,videos=video_inputs, padding=True, return_tensors='pt')
-inputs = inputs.to(model.device)
+inputs = inputs.to("cuda")
 
 eos_token_id = [tokenizer.eos_token_id,159858]
 streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
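This hunk, and the two below it, replace model.device with the literal "cuda" device string when moving tensors to the GPU. As a stand-alone illustration of the pattern (hypothetical shapes, with a CPU fallback added so the sketch also runs on machines without a GPU):

    import torch

    # Pick the target device explicitly; fall back to CPU when CUDA is absent.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = {"input_ids": torch.randint(0, 100, (1, 8)),
             "pixel_values": torch.randn(1, 3, 224, 224)}
    batch = {k: v.to(device) for k, v in batch.items()}  # mirrors inputs.to("cuda")
    print({k: v.device for k, v in batch.items()})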
@@ -134,7 +130,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
 
 if encoding_indices is not None:
     print("processing mesh...")
-    recon = vqvae.Decode(encoding_indices.to(model.device))
+    recon = vqvae.Decode(encoding_indices.to("cuda"))
     z_s = recon[0].detach().cpu()
     z_s = (z_s>0)*1
     indices = torch.nonzero(z_s[0] == 1)
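For context on the surrounding lines: the decoded grid is binarized with (z_s>0)*1, and torch.nonzero then returns the integer coordinates of the occupied voxels. A tiny self-contained sketch with made-up shapes (the real grid comes from vqvae.Decode):

    import torch

    # Made-up occupancy logits standing in for the VQ-VAE decoder output.
    logits = torch.randn(1, 16, 16, 16)
    occ = (logits > 0) * 1                # 1 where a voxel is "filled"
    coords = torch.nonzero(occ[0] == 1)   # (N, 3) integer voxel coordinates
    print(coords.shape)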
@@ -148,7 +144,7 @@ def predict(_chatbot,task_history,viewer_voxel,viewer_mesh,task_new,seed,top_k,t
 ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
 ss=ss.unsqueeze(0)
 coords = torch.argwhere(ss>0)[:, [0, 2, 3, 4]].int()
-coords = coords.to(model.device)
+coords = coords.to("cuda")
 try:
     print("processing mesh...")
     if len(image_lst) == 0:
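The [:, [0, 2, 3, 4]] indexing two lines above the change keeps the batch index plus the three spatial coordinates of a grid that, after the unsqueeze, appears to be laid out as (batch, channel, D, H, W), dropping the channel column. A self-contained sketch with hypothetical shapes:

    import torch

    # Hypothetical (batch, channel, D, H, W) grid with one occupied voxel.
    ss = torch.zeros(1, 1, 4, 4, 4)
    ss[0, 0, 1, 2, 3] = 1
    coords = torch.argwhere(ss > 0)[:, [0, 2, 3, 4]].int()  # drop the channel column
    print(coords)  # tensor([[0, 1, 2, 3]], dtype=torch.int32)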
 