anas-awadalla commited on
Commit
5842ec8
·
1 Parent(s): 231f7ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -61,7 +61,7 @@ model, image_processor, tokenizer = create_model_and_transforms(
61
 
62
  checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")
63
  model.load_state_dict(torch.load(checkpoint_path), strict=False)
64
- model.eval().to("cuda", dtype=torch.bfloat16)
65
 
66
  def generate(
67
  idx,
@@ -152,7 +152,7 @@ def generate(
152
 
153
  # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
154
  output = model.generate(
155
- vision_x=vision_x.to("cuda", dtype=torch.bfloat16),
156
  lang_x=input_ids.to("cuda"),
157
  attention_mask=attention_mask.to("cuda"),
158
  max_new_tokens=30,
 
61
 
62
  checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")
63
  model.load_state_dict(torch.load(checkpoint_path), strict=False)
64
+ model.eval().to("cuda")
65
 
66
  def generate(
67
  idx,
 
152
 
153
  # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
154
  output = model.generate(
155
+ vision_x=vision_x.to("cuda"),
156
  lang_x=input_ids.to("cuda"),
157
  attention_mask=attention_mask.to("cuda"),
158
  max_new_tokens=30,