# app.py
# Hugging Face Spaces page status: Runtime error
import logging
import re
from pathlib import Path

import gradio as gr
import sympy
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# One Hugging Face checkpoint supplies both the TrOCR processor and the
# encoder-decoder model. Both are built once at import time so every request
# reuses the same objects instead of reloading the weights.
_TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
logger = logging.getLogger(__name__)

# Upper bound on the recognized text handed to the SymPy parser.
_MAX_PROBLEM_LENGTH = 200

# Characters a handwritten arithmetic/algebra problem is expected to contain.
# Anything else (quotes, underscores, brackets, commas, ...) is rejected before
# the text reaches the parser.
_ALLOWED_PROBLEM_RE = re.compile(r"[0-9A-Za-z+\-*/^=().]+")

# A "." directly followed by a letter would become attribute access once the
# text is evaluated, so it is rejected too (decimals such as "1.5" still pass).
_ATTRIBUTE_ACCESS_RE = re.compile(r"\.[A-Za-z]")

# Unicode times / division / minus signs that OCR may emit, mapped to ASCII.
_OPERATOR_TRANSLATION = str.maketrans({"\u00d7": "*", "\u00f7": "/", "\u2212": "-"})

# The only SymPy names parsed text may resolve to. parse_expr's auto_symbol
# step turns every other name into an inert Symbol/Function. The first seven
# are emitted by the standard parser transformations themselves
# (auto_symbol, auto_number, repeated_decimals, factorial_notation).
_PARSE_NAMES = (
    "Integer", "Float", "Rational", "Symbol", "Function",
    "factorial", "factorial2",
    "sqrt", "exp", "log", "ln", "sin", "cos", "tan",
    "pi", "E", "I",
)


def _recognize_text(image):
    """Run TrOCR on *image* and return the recognized text, normalized.

    All whitespace is removed and Unicode operators are mapped to ASCII, so
    the result looks like e.g. ``"2x+3=7"``.
    """
    # The Gradio input already requests image_mode="RGB"; convert anyway so
    # the function is safe to call with other PIL modes.
    pixel_values = processor(images=image.convert("RGB"), return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return "".join(text.split()).translate(_OPERATOR_TRANSLATION)


def _parse_side(text):
    """Parse one side of an equation (or a bare expression) into SymPy."""
    from sympy.parsing.sympy_parser import (
        convert_xor,
        implicit_multiplication,
        parse_expr,
    standard_transformations,
    )

    # NOTE(security): parse_expr evaluates the transformed text with eval()
    # (see the SymPy docs warning) and the text comes from user uploads. The
    # character whitelist in _solve_problem plus this restricted namespace keep
    # builtins and the rest of SymPy out of reach. They do not bound the cost:
    # e.g. "9^9^9^9" can still exhaust CPU/memory.
    # A fresh dict is built per call because eval() inserts __builtins__ into it.
    namespace = {name: getattr(sympy, name) for name in _PARSE_NAMES}
    # convert_xor: "^" means power (as sympify does by default);
    # implicit_multiplication: "2x" means 2*x.
    transformations = standard_transformations + (convert_xor, implicit_multiplication)
    return parse_expr(text, global_dict=namespace, transformations=transformations)


def _format_solutions(solutions):
    """Render the return value of ``sympy.solve`` for display."""
    if isinstance(solutions, list):
        return ", ".join(str(s) for s in solutions) or "No solution found"
    return str(solutions)


def _solve_problem(problem_text):
    """Validate and solve *problem_text*, returning the solution as a string.

    - ``lhs=rhs`` with unknowns: ``sympy.solve(Eq(lhs, rhs))``.
    - ``lhs=rhs`` without unknowns: ``"True"``/``"False"`` (does it hold?).
    - Expression with unknowns: solved for ``expression == 0``.
    - Arithmetic such as ``"2+3"`` or ``"2+3="``: its simplified value.

    Raises:
        ValueError: The text is empty or too long, contains unsupported
            characters, has more than one ``=``, or has nothing left of ``=``.
        Exception: Anything raised by SymPy while parsing or solving.
    """
    if not problem_text:
        raise ValueError("no text was recognized in the image")
    if len(problem_text) > _MAX_PROBLEM_LENGTH:
        raise ValueError(
            f"recognized text is longer than {_MAX_PROBLEM_LENGTH} characters"
        )
    if (
        not _ALLOWED_PROBLEM_RE.fullmatch(problem_text)
        or _ATTRIBUTE_ACCESS_RE.search(problem_text)
    ):
        raise ValueError(
            f"recognized text contains unsupported characters: {problem_text!r}"
        )

    lhs_text, equals, rhs_text = problem_text.partition("=")
    if not lhs_text or "=" in rhs_text:
        raise ValueError(
            f"expected an expression or a single equation 'lhs=rhs', got {problem_text!r}"
        )
    lhs = _parse_side(lhs_text)

    if equals and rhs_text:
        relation = sympy.Eq(lhs, _parse_side(rhs_text))
        # Eq collapses to a boolean singleton when no unknowns remain
        # (e.g. "2+3=5"); solve() has nothing to solve in that case.
        if relation is sympy.true or relation is sympy.false:
            return str(relation)
        return _format_solutions(sympy.solve(relation))

    if not lhs.free_symbols:
        # Plain arithmetic ("2+3", or "2+3=" with an empty right side).
        # sympy.solve would return [] here, so evaluate the expression instead.
        return str(sympy.simplify(lhs))
    return _format_solutions(sympy.solve(lhs))


def solve_math_problem(image):
    """Recognize a handwritten math problem in *image* and solve it.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        Markdown with the recognized problem and its solution. Failures are
        returned as an error message rather than raised.
    """
    if image is None:
        return "**Error processing image:** no image was provided."
    try:
        problem_text = _recognize_text(image)
        solution = _solve_problem(problem_text)
    except Exception as e:
        # UI boundary: show any OCR/parse/solve failure to the user and keep
        # the traceback in the logs for debugging.
        logger.exception("Failed to solve handwritten math problem")
        return f"**Error processing image:** {str(e)}"
    # Inline code spans stop "*" in e.g. "2*x" or "x**2" from rendering as
    # Markdown emphasis.
    return f"**Problem:** `{problem_text}`\n\n**Solution:** `{solution}`"
# Example images are looked up next to this script, not in the current working
# directory. Gradio processes (and may cache) examples while building the
# interface, so a listed file that is missing can stop the app from starting.
# Only offer the examples that actually exist.
_EXAMPLE_DIR = Path(__file__).resolve().parent
_EXAMPLES = [
    [str(path)]
    for path in (
        _EXAMPLE_DIR / "example_addition.png",
        _EXAMPLE_DIR / "example_algebra.jpg",
    )
    if path.is_file()
]

# Create the Gradio interface
demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(
        type="pil",
        label="Upload Handwritten Math Problem",
        image_mode="RGB",  # fn receives an RGB PIL image
    ),
    outputs=gr.Markdown(),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem, and the app will attempt to solve it.",
    examples=_EXAMPLES or None,
    # NOTE(review): newer Gradio releases deprecate this in favor of
    # flagging_mode="never"; switch once the pinned Gradio version supports it.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()