anas-awadalla committed on
Commit fe2a8a1
1 Parent(s): d25e2be

Update app.py

Files changed (1)
  1. app.py +12 -12
app.py CHANGED
@@ -62,7 +62,7 @@ model, image_processor, tokenizer = create_model_and_transforms(
 checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")
 model.load_state_dict(torch.load(checkpoint_path), strict=False)
 
-model.eval()
+model.eval().to(0, dtype=torch.bfloat16)
 
 def generate(
     idx,
@@ -151,17 +151,17 @@ def generate(
     vision_x = vision_x.unsqueeze(1).unsqueeze(0)
     print(vision_x.shape)
 
-    # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-    output = model.generate(
-        vision_x=vision_x,
-        lang_x=input_ids,
-        attention_mask=attention_mask,
-        max_new_tokens=30,
-        num_beams=3,
-        # do_sample=True,
-        # temperature=0.3,
-        # top_k=0,
-    )
+    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+        output = model.generate(
+            vision_x=vision_x.to(0, dtype=torch.bfloat16),
+            lang_x=input_ids.to(0),
+            attention_mask=attention_mask.to(0),
+            max_new_tokens=30,
+            num_beams=3,
+            # do_sample=True,
+            # temperature=0.3,
+            # top_k=0,
+        )
 
     gen_text = tokenizer.decode(
         output[0][len(input_ids[0]):], skip_special_tokens=True
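For context, here is a minimal sketch of the inference path once this commit is applied. Only the model.eval().to(0, dtype=torch.bfloat16) call and the autocast-wrapped generate() with inputs moved to device 0 come from the diff above; the create_model_and_transforms() arguments and the prompt/image handling are not part of this commit and are assumptions for illustration.

import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",                      # assumed
    clip_vision_encoder_pretrained="openai",                  # assumed
    lang_encoder_path="anas-awadalla/mpt-7b-redpajama-200b",  # assumed
    tokenizer_path="anas-awadalla/mpt-7b-redpajama-200b",     # assumed
    cross_attn_every_n_layers=4,                              # assumed
)

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)

# First change in the commit: keep the model on GPU 0 in bfloat16.
model.eval().to(0, dtype=torch.bfloat16)

def generate(image: Image.Image, prompt: str) -> str:
    # OpenFlamingo expects vision input shaped (batch, num_media, num_frames, C, H, W).
    vision_x = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    lang_x = tokenizer([prompt], return_tensors="pt")

    # Second change in the commit: run generation under bfloat16 autocast and
    # move every input tensor to the same device as the model.
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.generate(
            vision_x=vision_x.to(0, dtype=torch.bfloat16),
            lang_x=lang_x["input_ids"].to(0),
            attention_mask=lang_x["attention_mask"].to(0),
            max_new_tokens=30,
            num_beams=3,
        )

    # Decode only the newly generated tokens, as in the app.
    return tokenizer.decode(output[0][len(lang_x["input_ids"][0]):], skip_special_tokens=True)

Casting the 9B checkpoint to bfloat16 roughly halves its GPU memory footprint compared with float32, and wrapping generate() in autocast keeps intermediate ops in the matching mixed-precision dtype; the explicit .to(0) calls are needed because the inputs are built on CPU while the model now lives on GPU 0.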