HenryShan committed

Commit f5cfe60 · verified · 1 Parent(s): c90d6e8

Update app.py

Files changed (1): app.py (+14 -7)
app.py CHANGED

```diff
@@ -12,7 +12,7 @@ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
 
-def describe_image(image, user_question="Solve this AP Problem step by step and explain to the student who don't know how to solve it"):
+def describe_image(image, user_question="You are the best AP teacher in the world. Analyze the AP problem in the image, and solve it step by step to let a student who don't know how to solve it understand"):
     try:
         # Convert the PIL Image to a BytesIO object for compatibility
         image_byte_arr = BytesIO()
@@ -43,21 +43,31 @@ def describe_image(image, user_question="Solve this AP Problem step by step and
             force_batchify=True
         )
 
+        # Explicitly cast all tensors in prepare_inputs to torch.float16
+        prepare_inputs = {
+            k: v.to(torch.float16) if isinstance(v, torch.Tensor) else v
+            for k, v in prepare_inputs.items()
+        }
+
         # Load and prepare the model
         vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.float16).eval()
-        vl_gpt = vl_gpt.to(torch.float16)
+        vl_gpt = vl_gpt.to(torch.float16)  # Explicitly ensure all components are in float16
 
         # Generate embeddings from the image input
         inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs).to(dtype=torch.float16)
 
+        # Ensure attention mask is also in torch.float16
+        attention_mask = prepare_inputs["attention_mask"].to(vl_gpt.device).to(dtype=torch.float16)
+
+        # Debugging: Print tensor dtypes
         print(f"Inputs Embeds dtype: {inputs_embeds.dtype}")
         print(f"Attention Mask dtype: {attention_mask.dtype}")
         print(f"Model dtype: {next(vl_gpt.parameters()).dtype}")
 
         # Generate the model's response
         outputs = vl_gpt.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask = prepare_inputs.attention_mask.to(vl_gpt.device).to(dtype=torch.float16),
+            inputs_embeds=inputs_embeds.to(torch.float16),
+            attention_mask=attention_mask.to(torch.float16),
             pad_token_id=tokenizer.eos_token_id,
             bos_token_id=tokenizer.bos_token_id,
             eos_token_id=tokenizer.eos_token_id,
@@ -65,16 +75,13 @@ def describe_image(image, user_question="Solve this AP Problem step by step and
             do_sample=False,
             use_cache=True
         )
-        outputs = outputs.to(torch.float16)
 
         # Decode the generated tokens into text
         answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
         return answer
-
     except Exception as e:
         # Provide detailed error information
         return f"Error: {str(e)}"
-
 # Gradio interface
 def gradio_app():
     with gr.Blocks() as demo:
```
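
The heart of the change is the blanket float16 cast over `prepare_inputs`. One thing worth noting: the comprehension converts every tensor in the dict, including integer fields such as token ids and the attention mask, which embedding lookups and Hugging Face's `generate` typically expect to remain integral. A more selective variant is sketched below (an assumption about what `prepare_inputs` contains, not part of this commit; the helper name is hypothetical): it casts only floating-point tensors and leaves ids and masks in their original dtype.

```python
import torch

def cast_float_tensors_to_half(inputs: dict) -> dict:
    # Hypothetical helper (not in the commit): cast only floating-point
    # tensors to float16. Integer tensors such as input_ids and
    # attention_mask keep their original dtype, which is what embedding
    # lookups and generate() usually expect.
    return {
        k: v.to(torch.float16)
        if isinstance(v, torch.Tensor) and v.is_floating_point()
        else v
        for k, v in inputs.items()
    }
```

Whether the blanket cast or the selective one is right depends on what the processor actually puts into `prepare_inputs`; the debug prints added in the same commit exist to check exactly that.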
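Those three `print` calls added before `generate` are the commit's debugging aid for this kind of dtype mismatch. The same check can be factored into a small reusable helper; this is a sketch, and `report_dtypes` is a hypothetical name, not something the commit defines:

```python
import torch

def report_dtypes(model: torch.nn.Module, **tensors: torch.Tensor) -> None:
    # Mirror the commit's debug prints: show each tensor's dtype/device
    # next to the model's parameter dtype, so mismatches are visible
    # before generate() runs.
    for name, t in tensors.items():
        print(f"{name}: dtype={t.dtype}, device={t.device}, shape={tuple(t.shape)}")
    print(f"model parameters: dtype={next(model.parameters()).dtype}")
```

Called as `report_dtypes(vl_gpt, inputs_embeds=inputs_embeds, attention_mask=attention_mask)` just before `generate`, it reproduces the commit's three prints in one place.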
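The other half of the commit simply rewords the default `user_question`, so a quick smoke test is to call `describe_image` with the question omitted and let the new default apply (the image path below is a placeholder, not a file from the repo):

```python
from PIL import Image

# Example invocation; "ap_problem.png" is a placeholder path.
img = Image.open("ap_problem.png")
print(describe_image(img))  # falls back to the new default prompt
```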