edward2021 committed
Commit 334df79
1 Parent(s): 1ad306d

add text reponse

Files changed (2)
  1. app.py +4 -4
  2. pq3d/inference.py +6 -1
app.py CHANGED
@@ -50,14 +50,14 @@ with gr.Blocks(title='PQ3D Demo') as demo:
 
     def inference_wrapper(text):
         scan_id = model_3d.value['orig_name'].split('.')[0]
-        inst_id = inference(scan_id, text)
-        return f"assets/mask/{scan_id}/{scan_id}_obj_{inst_id}.glb"
+        inst_id, response = inference(scan_id, text)
+        return f"assets/mask/{scan_id}/{scan_id}_obj_{inst_id}.glb", response
 
     gr.Interface(
         fn=inference_wrapper,
         inputs=["text"],
-        outputs=gr.Model3D(
-            clear_color=[0.0, 0.0, 0.0, 0.0], camera_position=(80, 100, 6), label="3D Model"),
+        outputs=[gr.Model3D(
+            clear_color=[0.0, 0.0, 0.0, 0.0], camera_position=(80, 100, 6), label="3D Model"), "text"],
         examples=[
             ["armchair"], ["Sofa"], ["left computer on the desk"]
         ],
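For context, the app.py change has inference_wrapper return a (glb_path, response) pair and lists two output components so Gradio can route the tuple positionally. Below is a minimal, self-contained sketch of that pattern; dummy_inference and the example .glb path are hypothetical stand-ins, not part of this repo.

```python
import gradio as gr

def dummy_inference(text):
    # Hypothetical stand-in for pq3d's inference(scan_id, text), which
    # (after this commit) returns an instance id and a text response.
    inst_id, response = 0, f"a {text} was selected"
    # The .glb path below is illustrative only.
    return f"assets/mask/scene0050_00/scene0050_00_obj_{inst_id}.glb", response

demo = gr.Interface(
    fn=dummy_inference,
    inputs=["text"],
    # Two outputs: the first tuple element feeds the Model3D viewer,
    # the second feeds a plain text component.
    outputs=[gr.Model3D(label="3D Model"), "text"],
)

if __name__ == "__main__":
    demo.launch()
```

Returning a tuple whose length matches the outputs list is what lets the same wrapper drive both the 3D viewer and the new text box.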
pq3d/inference.py CHANGED
@@ -171,9 +171,14 @@ def inference(scan_id, text):
     model = Query3DUnified()
     load_msg = model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'), strict=False)
     data_dict = model(data_dict)
+    # calculate result id
     result_id = data_dict['obj_ids'][0][torch.argmax(data_dict['og3d_logits'][0]).item()]
     print(f"finish infernece result id is {result_id}")
-    return result_id
+    # calculate langauge
+    tokenizer = AutoTokenizer.from_pretrained("t5-small")
+    response_pred = tokenizer.batch_decode(data_dict['generation_logits'], skip_special_tokens=True)[0]
+    print(f"text response is {response_pred}")
+    return result_id, response_pred
 
 if __name__ == '__main__':
     inference("scene0050_00", "chair")