Update app.py
app.py CHANGED
```diff
@@ -12,7 +12,7 @@ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
 
-def describe_image(image, user_question="Solve this AP Problem step by step and
+def describe_image(image, user_question="You are the best AP teacher in the world. Analyze the AP problem in the image, and solve it step by step to let a student who don't know how to solve it understand"):
     try:
         # Convert the PIL Image to a BytesIO object for compatibility
         image_byte_arr = BytesIO()
```
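The context lines cut off right after `image_byte_arr = BytesIO()`. For readers reconstructing the elided step, a PIL-to-bytes round-trip of this shape is the usual pattern; the helper name and the PNG format are my assumptions, not from app.py:

```python
from io import BytesIO
from PIL import Image

# Hypothetical helper (not in app.py) illustrating the elided conversion:
# serialize a PIL image into raw bytes that an image processor can consume.
def pil_to_bytes(image: Image.Image) -> bytes:
    image_byte_arr = BytesIO()
    image.save(image_byte_arr, format="PNG")  # assumed format
    return image_byte_arr.getvalue()
```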
```diff
@@ -43,21 +43,31 @@ def describe_image(image, user_question="Solve this AP Problem step by step and
             force_batchify=True
         )
 
+        # Explicitly cast all tensors in prepare_inputs to torch.float16
+        prepare_inputs = {
+            k: v.to(torch.float16) if isinstance(v, torch.Tensor) else v
+            for k, v in prepare_inputs.items()
+        }
+
         # Load and prepare the model
         vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.float16).eval()
-        vl_gpt = vl_gpt.to(torch.float16)
+        vl_gpt = vl_gpt.to(torch.float16)  # Explicitly ensure all components are in float16
 
         # Generate embeddings from the image input
         inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs).to(dtype=torch.float16)
 
+        # Ensure attention mask is also in torch.float16
+        attention_mask = prepare_inputs["attention_mask"].to(vl_gpt.device).to(dtype=torch.float16)
+
+        # Debugging: Print tensor dtypes
         print(f"Inputs Embeds dtype: {inputs_embeds.dtype}")
         print(f"Attention Mask dtype: {attention_mask.dtype}")
         print(f"Model dtype: {next(vl_gpt.parameters()).dtype}")
 
         # Generate the model's response
         outputs = vl_gpt.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask
+            inputs_embeds=inputs_embeds.to(torch.float16),
+            attention_mask=attention_mask.to(torch.float16),
             pad_token_id=tokenizer.eos_token_id,
             bos_token_id=tokenizer.bos_token_id,
             eos_token_id=tokenizer.eos_token_id,
```
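One caveat worth flagging about this hunk: the blanket comprehension also casts integer tensors such as `input_ids` and `attention_mask` to float16, and `generate` is then handed a float16 mask. Hugging Face models generally expect those to stay integer or bool (embedding lookups index with longs, and masks are compared rather than multiplied), so this is a likely source of dtype surprises rather than a cure for them. A more conservative variant, a sketch and not what the commit ships, casts only the floating-point tensors (e.g. pixel values):

```python
import torch

# Sketch only (not in the commit): cast floating-point tensors to float16,
# leaving integer/bool tensors such as input_ids and attention_mask untouched.
def cast_floats_to_fp16(inputs: dict) -> dict:
    return {
        k: v.to(torch.float16)
        if isinstance(v, torch.Tensor) and torch.is_floating_point(v)
        else v
        for k, v in inputs.items()
    }

# Usage: prepare_inputs = cast_floats_to_fp16(prepare_inputs)
```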
```diff
@@ -65,16 +75,13 @@ def describe_image(image, user_question="Solve this AP Problem step by step and
             do_sample=False,
             use_cache=True
         )
-        outputs = outputs.to(torch.float16)
 
         # Decode the generated tokens into text
         answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
         return answer
-
     except Exception as e:
         # Provide detailed error information
         return f"Error: {str(e)}"
-
 # Gradio interface
 def gradio_app():
     with gr.Blocks() as demo:
```
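The removal of `outputs = outputs.to(torch.float16)` is the clearly correct part of this hunk: `generate` returns integer token ids, so casting them to float16 would corrupt them before decoding. A self-contained sketch of the decode step, with a helper name that is mine rather than app.py's:

```python
import torch
from transformers import PreTrainedTokenizerBase

# Hypothetical helper (not in app.py): decode the first generated sequence.
# No dtype cast is needed or safe here, since generate() yields integer ids.
def decode_first_sequence(outputs: torch.Tensor,
                          tokenizer: PreTrainedTokenizerBase) -> str:
    return tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
```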
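The diff's context stops at `with gr.Blocks() as demo:`. For completeness, a minimal wiring of `describe_image` into a Blocks UI might look like the sketch below; the component names and layout are assumptions, since the commit does not show them:

```python
import gradio as gr

# Sketch only: assumes describe_image from app.py is in scope.
def gradio_app():
    with gr.Blocks() as demo:
        image_input = gr.Image(type="pil", label="AP problem image")
        question = gr.Textbox(label="Question (optional)")
        answer = gr.Textbox(label="Model answer")
        solve = gr.Button("Solve")
        solve.click(describe_image, inputs=[image_input, question], outputs=answer)
    return demo

# Typical entry point: gradio_app().launch()
```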