prithivMLmods commited on
Commit
b0ba3ed
·
verified ·
1 Parent(s): 4890db6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -50,15 +50,22 @@ class Model:
50
  self.pipe.to(self.device)
51
  # Ensure the text encoder is in half precision to avoid dtype mismatches.
52
  if torch.cuda.is_available():
53
- self.pipe.text_encoder = self.pipe.text_encoder.half()
 
 
 
54
 
55
  self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
56
  self.pipe_img.to(self.device)
 
57
  if torch.cuda.is_available():
58
- self.pipe_img.text_encoder = self.pipe_img.text_encoder.half()
 
 
59
 
60
  def to_glb(self, ply_path: str) -> str:
61
  mesh = trimesh.load(ply_path)
 
62
  rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
63
  mesh.apply_transform(rot)
64
  rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
@@ -447,4 +454,4 @@ demo = gr.ChatInterface(
447
 
448
  if __name__ == "__main__":
449
  # To create a public link, set share=True in launch().
450
- demo.queue(max_size=20).launch(share=True)
 
50
  self.pipe.to(self.device)
51
  # Ensure the text encoder is in half precision to avoid dtype mismatches.
52
  if torch.cuda.is_available():
53
+ try:
54
+ self.pipe.text_encoder = self.pipe.text_encoder.half()
55
+ except AttributeError:
56
+ pass
57
 
58
  self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
59
  self.pipe_img.to(self.device)
60
+ # Use getattr with a default value to avoid AttributeError if text_encoder is missing.
61
  if torch.cuda.is_available():
62
+ text_encoder_img = getattr(self.pipe_img, "text_encoder", None)
63
+ if text_encoder_img is not None:
64
+ self.pipe_img.text_encoder = text_encoder_img.half()
65
 
66
  def to_glb(self, ply_path: str) -> str:
67
  mesh = trimesh.load(ply_path)
68
+ # Rotate the mesh for proper orientation
69
  rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
70
  mesh.apply_transform(rot)
71
  rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
 
454
 
455
  if __name__ == "__main__":
456
  # To create a public link, set share=True in launch().
457
+ demo.queue(max_size=20).launch(share=True)