Nitin00043 committed on
Commit
11d8425
·
verified ·
1 Parent(s): 5093ea9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -2,8 +2,13 @@ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
2
  import gradio as gr
3
  from PIL import Image
4
 
5
- # Load the pre-trained Pix2Struct model and processor.
6
- model_name = "google/pix2struct-mathqa-base"
 
 
 
 
 
7
  model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
8
  processor = Pix2StructProcessor.from_pretrained(model_name)
9
 
@@ -13,12 +18,12 @@ def solve_math_problem(image):
13
  image = image.convert("RGB")
14
 
15
  # Preprocess the image and text.
16
- # Note: We omit the header_text parameter because this is not a VQA task.
17
  inputs = processor(
18
- images=[image], # Provide a list of images.
19
- text="Solve the following math problem:", # Prompt text.
20
  return_tensors="pt",
21
- max_patches=2048 # Increase the maximum patches for better math handling.
22
  )
23
 
24
  # Generate the solution with specified generation parameters.
@@ -30,8 +35,7 @@ def solve_math_problem(image):
30
  temperature=0.2
31
  )
32
 
33
- # Decode the input text and the model prediction.
34
- # Here, we access "input_ids" via the dictionary key.
35
  problem_text = processor.decode(
36
  inputs["input_ids"][0],
37
  skip_special_tokens=True,
@@ -54,14 +58,14 @@ demo = gr.Interface(
54
  inputs=gr.Image(
55
  type="pil",
56
  label="Upload Handwritten Math Problem",
57
- image_mode="RGB", # Force RGB conversion.
58
  source="upload"
59
  ),
60
  outputs=gr.Textbox(label="Solution", show_copy_button=True),
61
  title="Handwritten Math Problem Solver",
62
  description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
63
  examples=[
64
- ["example_addition.png"], # Ensure these example files exist in your working directory.
65
  ["example_algebra.jpg"]
66
  ],
67
  theme="soft",
 
2
  import gradio as gr
3
  from PIL import Image
4
 
5
+ # Use a public model identifier; change this if you have a different one or want to use a private model.
6
+ model_name = "google/pix2struct-textcaps-base"
7
+
8
+ # If you need authentication for a private repo, pass the token as follows:
9
+ # model = Pix2StructForConditionalGeneration.from_pretrained(model_name, use_auth_token="YOUR_TOKEN")
10
+ # processor = Pix2StructProcessor.from_pretrained(model_name, use_auth_token="YOUR_TOKEN")
11
+
12
  model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
13
  processor = Pix2StructProcessor.from_pretrained(model_name)
14
 
 
18
  image = image.convert("RGB")
19
 
20
  # Preprocess the image and text.
21
+ # Note: We omit header_text since this is not a VQA task.
22
  inputs = processor(
23
+ images=[image],
24
+ text="Solve the following math problem:",
25
  return_tensors="pt",
26
+ max_patches=2048
27
  )
28
 
29
  # Generate the solution with specified generation parameters.
 
35
  temperature=0.2
36
  )
37
 
38
+ # Decode the problem text and the generated solution.
 
39
  problem_text = processor.decode(
40
  inputs["input_ids"][0],
41
  skip_special_tokens=True,
 
58
  inputs=gr.Image(
59
  type="pil",
60
  label="Upload Handwritten Math Problem",
61
+ image_mode="RGB",
62
  source="upload"
63
  ),
64
  outputs=gr.Textbox(label="Solution", show_copy_button=True),
65
  title="Handwritten Math Problem Solver",
66
  description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
67
  examples=[
68
+ ["example_addition.png"],
69
  ["example_algebra.jpg"]
70
  ],
71
  theme="soft",