Nitin00043 commited on
Commit
5093ea9
·
verified ·
1 Parent(s): 3dc1b9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -2,24 +2,26 @@ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
2
  import gradio as gr
3
  from PIL import Image
4
 
5
- # Load the pre-trained Pix2Struct model and processor
6
  model_name = "google/pix2struct-mathqa-base"
7
  model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
8
  processor = Pix2StructProcessor.from_pretrained(model_name)
9
 
10
  def solve_math_problem(image):
11
  try:
12
- # Preprocess the image
13
- image = image.convert("RGB") # Ensure RGB format
 
 
 
14
  inputs = processor(
15
- images=[image], # Wrap in list
16
- text="Solve the following math problem:", # More specific prompt
17
  return_tensors="pt",
18
- max_patches=2048, # Increased from default 1024 for better math handling
19
- header_text="Math Problem" # Add header text
20
  )
21
 
22
- # Generate the solution
23
  predictions = model.generate(
24
  **inputs,
25
  max_new_tokens=200,
@@ -28,33 +30,38 @@ def solve_math_problem(image):
28
  temperature=0.2
29
  )
30
 
31
- # Decode the output
 
 
 
 
 
 
32
  solution = processor.decode(
33
- predictions[0],
34
  skip_special_tokens=True,
35
  clean_up_tokenization_spaces=True
36
  )
37
 
38
- # Format the solution
39
- return f"Problem: {processor.decode(inputs.input_ids[0])}\nSolution: {solution}"
40
 
41
  except Exception as e:
42
  return f"Error processing image: {str(e)}"
43
 
44
- # Gradio interface with explicit image handling
45
  demo = gr.Interface(
46
  fn=solve_math_problem,
47
  inputs=gr.Image(
48
  type="pil",
49
  label="Upload Handwritten Math Problem",
50
- image_mode="RGB", # Force RGB format
51
  source="upload"
52
  ),
53
  outputs=gr.Textbox(label="Solution", show_copy_button=True),
54
  title="Handwritten Math Problem Solver",
55
  description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
56
  examples=[
57
- ["example_addition.png"], # Make sure to upload these files
58
  ["example_algebra.jpg"]
59
  ],
60
  theme="soft",
@@ -62,4 +69,4 @@ demo = gr.Interface(
62
  )
63
 
64
  if __name__ == "__main__":
65
- demo.launch()
 
2
  import gradio as gr
3
  from PIL import Image
4
 
5
+ # Load the pre-trained Pix2Struct model and processor.
6
  model_name = "google/pix2struct-mathqa-base"
7
  model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
8
  processor = Pix2StructProcessor.from_pretrained(model_name)
9
 
10
  def solve_math_problem(image):
11
  try:
12
+ # Ensure the image is in RGB format.
13
+ image = image.convert("RGB")
14
+
15
+ # Preprocess the image and text.
16
+ # Note: We omit the header_text parameter because this is not a VQA task.
17
  inputs = processor(
18
+ images=[image], # Provide a list of images.
19
+ text="Solve the following math problem:", # Prompt text.
20
  return_tensors="pt",
21
+ max_patches=2048 # Increase the maximum patches for better math handling.
 
22
  )
23
 
24
+ # Generate the solution with specified generation parameters.
25
  predictions = model.generate(
26
  **inputs,
27
  max_new_tokens=200,
 
30
  temperature=0.2
31
  )
32
 
33
+ # Decode the input text and the model prediction.
34
+ # Here, we access "input_ids" via the dictionary key.
35
+ problem_text = processor.decode(
36
+ inputs["input_ids"][0],
37
+ skip_special_tokens=True,
38
+ clean_up_tokenization_spaces=True
39
+ )
40
  solution = processor.decode(
41
+ predictions[0],
42
  skip_special_tokens=True,
43
  clean_up_tokenization_spaces=True
44
  )
45
 
46
+ return f"Problem: {problem_text}\nSolution: {solution}"
 
47
 
48
  except Exception as e:
49
  return f"Error processing image: {str(e)}"
50
 
51
+ # Set up the Gradio interface.
52
  demo = gr.Interface(
53
  fn=solve_math_problem,
54
  inputs=gr.Image(
55
  type="pil",
56
  label="Upload Handwritten Math Problem",
57
+ image_mode="RGB", # Force RGB conversion.
58
  source="upload"
59
  ),
60
  outputs=gr.Textbox(label="Solution", show_copy_button=True),
61
  title="Handwritten Math Problem Solver",
62
  description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
63
  examples=[
64
+ ["example_addition.png"], # Ensure these example files exist in your working directory.
65
  ["example_algebra.jpg"]
66
  ],
67
  theme="soft",
 
69
  )
70
 
71
  if __name__ == "__main__":
72
+ demo.launch()