Nitin00043 commited on
Commit
45a182b
·
verified ·
1 Parent(s): 11d8425

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -12
app.py CHANGED
@@ -2,13 +2,8 @@ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
2
  import gradio as gr
3
  from PIL import Image
4
 
5
- # Use a public model identifier; change this if you have a different one or want to use a private model.
6
  model_name = "google/pix2struct-textcaps-base"
7
-
8
- # If you need authentication for a private repo, pass the token as follows:
9
- # model = Pix2StructForConditionalGeneration.from_pretrained(model_name, use_auth_token="YOUR_TOKEN")
10
- # processor = Pix2StructProcessor.from_pretrained(model_name, use_auth_token="YOUR_TOKEN")
11
-
12
  model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
13
  processor = Pix2StructProcessor.from_pretrained(model_name)
14
 
@@ -17,8 +12,7 @@ def solve_math_problem(image):
17
  # Ensure the image is in RGB format.
18
  image = image.convert("RGB")
19
 
20
- # Preprocess the image and text.
21
- # Note: We omit header_text since this is not a VQA task.
22
  inputs = processor(
23
  images=[image],
24
  text="Solve the following math problem:",
@@ -26,7 +20,7 @@ def solve_math_problem(image):
26
  max_patches=2048
27
  )
28
 
29
- # Generate the solution with specified generation parameters.
30
  predictions = model.generate(
31
  **inputs,
32
  max_new_tokens=200,
@@ -35,7 +29,7 @@ def solve_math_problem(image):
35
  temperature=0.2
36
  )
37
 
38
- # Decode the problem text and the generated solution.
39
  problem_text = processor.decode(
40
  inputs["input_ids"][0],
41
  skip_special_tokens=True,
@@ -58,8 +52,7 @@ demo = gr.Interface(
58
  inputs=gr.Image(
59
  type="pil",
60
  label="Upload Handwritten Math Problem",
61
- image_mode="RGB",
62
- source="upload"
63
  ),
64
  outputs=gr.Textbox(label="Solution", show_copy_button=True),
65
  title="Handwritten Math Problem Solver",
 
2
  import gradio as gr
3
  from PIL import Image
4
 
5
+ # Use a public model identifier. If you need a private model, remember to authenticate.
6
  model_name = "google/pix2struct-textcaps-base"
 
 
 
 
 
7
  model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
8
  processor = Pix2StructProcessor.from_pretrained(model_name)
9
 
 
12
  # Ensure the image is in RGB format.
13
  image = image.convert("RGB")
14
 
15
+ # Preprocess the image and text. Note that header_text is omitted as it's not used for non-VQA tasks.
 
16
  inputs = processor(
17
  images=[image],
18
  text="Solve the following math problem:",
 
20
  max_patches=2048
21
  )
22
 
23
+ # Generate the solution with generation parameters.
24
  predictions = model.generate(
25
  **inputs,
26
  max_new_tokens=200,
 
29
  temperature=0.2
30
  )
31
 
32
+ # Decode the problem text and generated solution.
33
  problem_text = processor.decode(
34
  inputs["input_ids"][0],
35
  skip_special_tokens=True,
 
52
  inputs=gr.Image(
53
  type="pil",
54
  label="Upload Handwritten Math Problem",
55
+ image_mode="RGB" # This forces the input to be RGB.
 
56
  ),
57
  outputs=gr.Textbox(label="Solution", show_copy_button=True),
58
  title="Handwritten Math Problem Solver",