# Hugging Face Space status when this file was captured: runtime error.
from pathlib import Path

import gradio as gr
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Checkpoint shared by the processor and the model below.
# NOTE(review): confirm this checkpoint id is actually published on the
# Hugging Face Hub; `from_pretrained` fails at import time if it is not.
model_name = "google/pix2struct-mathqa-base"

# The processor prepares images/text for the model; the model generates text.
# Both are loaded once at import time and reused by every request.
processor = Pix2StructProcessor.from_pretrained(model_name)
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
def solve_math_problem(image):
    """Run the Pix2Struct model on an uploaded image and format the result.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the form is submitted without an upload.

    Returns:
        A two-line string ("Problem: ..." followed by "Solution: ...") on
        success, or a string starting with "Error processing image: " when
        the input is missing or any processing step raises. Exceptions are
        never propagated, because the return value is shown in a textbox.
    """
    # Nothing uploaded: answer with a readable message instead of leaking an
    # AttributeError about NoneType through the generic handler below.
    if image is None:
        return "Error processing image: no image was provided. Please upload an image first."

    prompt = "Solve the following math problem:"
    try:
        # Ensure the image is in RGB format.
        image = image.convert("RGB")

        # Preprocess the image and the prompt. The text is passed as a plain
        # prompt (not as VQA header text rendered onto the image).
        inputs = processor(
            images=[image],  # Provide a list of images.
            text=prompt,
            return_tensors="pt",
            max_patches=2048,  # Patch budget for the flattened image.
            # The prompt is used as a decoder prefix; an end-of-sequence token
            # appended to it would tell the decoder the text is finished.
            add_special_tokens=False,
        )

        # Beam search is deterministic, so no `temperature` is passed: it is
        # only used when sampling and otherwise just triggers a warning.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )

        # The processor stores the tokenized prompt under "decoder_input_ids"
        # for non-VQA checkpoints; "input_ids" is kept as a fallback so a
        # processor that uses the generic key keeps working.
        prompt_ids = None
        for key in ("decoder_input_ids", "input_ids"):
            if key in inputs:
                prompt_ids = inputs[key]
                break

        if prompt_ids is None:
            # No tokenized prompt was returned; show the raw prompt instead.
            problem_text = prompt
        else:
            problem_text = processor.decode(
                prompt_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return f"Problem: {problem_text}\nSolution: {solution}"
    except Exception as e:
        # UI boundary: report the failure in the output textbox rather than
        # crashing the request.
        return f"Error processing image: {str(e)}"
# Set up the Gradio interface.

# Only offer example images that are actually present in the working
# directory; pointing Gradio at a missing example file can make the app fail
# while the interface is being built.
_EXAMPLE_FILES = ("example_addition.png", "example_algebra.jpg")
_available_examples = [[name] for name in _EXAMPLE_FILES if Path(name).is_file()]

demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(
        type="pil",
        label="Upload Handwritten Math Problem",
        image_mode="RGB",  # Force RGB conversion.
        # No `source="upload"` keyword: Gradio 4 replaced it with `sources`
        # and rejects the old name, while on Gradio 3 "upload" is already the
        # default, so omitting it works on both.
    ),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
    # `None` hides the examples panel when no example file is available.
    examples=_available_examples or None,
    theme="soft",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()