Nitin00043 committed on
Commit
de612dc
·
verified ·
1 Parent(s): d48368d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -63
app.py CHANGED
@@ -1,71 +1,71 @@
1
- import torch
2
- from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
3
- import gradio as gr
4
- from PIL import Image
5
 
6
- # Use a publicly available high-capacity model.
7
- # For instance, we use "google/pix2struct-docvqa-large".
8
- # (If you need a different model or a private one, adjust accordingly and add authentication if necessary.)
9
- model_name = "google/pix2struct-docvqa-large"
10
 
11
- model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
12
- processor = Pix2StructProcessor.from_pretrained(model_name)
13
 
14
- def solve_problem(image):
15
- try:
16
- # Ensure the image is in RGB.
17
- image = image.convert("RGB")
18
 
19
- # Preprocess image and text prompt.
20
- inputs = processor(
21
- images=[image],
22
- text="Solve the following problem:",
23
- return_tensors="pt",
24
- max_patches=2048
25
- )
26
 
27
- # Generate prediction.
28
- predictions = model.generate(
29
- **inputs,
30
- max_new_tokens=200,
31
- early_stopping=True,
32
- num_beams=4,
33
- temperature=0.2
34
- )
35
 
36
- # Decode the prompt (input IDs) and the generated output.
37
- problem_text = processor.decode(
38
- inputs["input_ids"][0],
39
- skip_special_tokens=True,
40
- clean_up_tokenization_spaces=True
41
- )
42
- solution = processor.decode(
43
- predictions[0],
44
- skip_special_tokens=True,
45
- clean_up_tokenization_spaces=True
46
- )
47
- return f"Problem: {problem_text}\nSolution: {solution}"
48
- except Exception as e:
49
- return f"Error processing image: {str(e)}"
50
 
51
- # Set up the Gradio interface.
52
- iface = gr.Interface(
53
- fn=solve_problem,
54
- inputs=gr.Image(type="pil", label="Upload Your Problem Image", image_mode="RGB"),
55
- outputs=gr.Textbox(label="Solution", show_copy_button=True),
56
- title="Problem Solver with Pix2Struct",
57
- description=(
58
- "Upload an image (for example, a handwritten math or logic problem) "
59
- "and get a solution generated by a high-capacity Pix2Struct model.\n\n"
60
- "Note: For best results on domain-specific tasks, consider fine-tuning on your own dataset."
61
- ),
62
- examples=[
63
- ["example_problem1.png"],
64
- ["example_problem2.jpg"]
65
- ],
66
- theme="soft",
67
- allow_flagging="never"
68
- )
69
 
70
- if __name__ == "__main__":
71
- iface.launch()
 
1
+ # import torch
2
+ # from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
3
+ # import gradio as gr
4
+ # from PIL import Image
5
 
6
+ # # Use a publicly available high-capacity model.
7
+ # # For instance, we use "google/pix2struct-docvqa-large".
8
+ # # (If you need a different model or a private one, adjust accordingly and add authentication if necessary.)
9
+ # model_name = "google/pix2struct-docvqa-large"
10
 
11
+ # model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
12
+ # processor = Pix2StructProcessor.from_pretrained(model_name)
13
 
14
+ # def solve_problem(image):
15
+ # try:
16
+ # # Ensure the image is in RGB.
17
+ # image = image.convert("RGB")
18
 
19
+ # # Preprocess image and text prompt.
20
+ # inputs = processor(
21
+ # images=[image],
22
+ # text="Solve the following problem:",
23
+ # return_tensors="pt",
24
+ # max_patches=2048
25
+ # )
26
 
27
+ # # Generate prediction.
28
+ # predictions = model.generate(
29
+ # **inputs,
30
+ # max_new_tokens=200,
31
+ # early_stopping=True,
32
+ # num_beams=4,
33
+ # temperature=0.2
34
+ # )
35
 
36
+ # # Decode the prompt (input IDs) and the generated output.
37
+ # problem_text = processor.decode(
38
+ # inputs["input_ids"][0],
39
+ # skip_special_tokens=True,
40
+ # clean_up_tokenization_spaces=True
41
+ # )
42
+ # solution = processor.decode(
43
+ # predictions[0],
44
+ # skip_special_tokens=True,
45
+ # clean_up_tokenization_spaces=True
46
+ # )
47
+ # return f"Problem: {problem_text}\nSolution: {solution}"
48
+ # except Exception as e:
49
+ # return f"Error processing image: {str(e)}"
50
 
51
+ # # Set up the Gradio interface.
52
+ # iface = gr.Interface(
53
+ # fn=solve_problem,
54
+ # inputs=gr.Image(type="pil", label="Upload Your Problem Image", image_mode="RGB"),
55
+ # outputs=gr.Textbox(label="Solution", show_copy_button=True),
56
+ # title="Problem Solver with Pix2Struct",
57
+ # description=(
58
+ # "Upload an image (for example, a handwritten math or logic problem) "
59
+ # "and get a solution generated by a high-capacity Pix2Struct model.\n\n"
60
+ # "Note: For best results on domain-specific tasks, consider fine-tuning on your own dataset."
61
+ # ),
62
+ # examples=[
63
+ # ["example_problem1.png"],
64
+ # ["example_problem2.jpg"]
65
+ # ],
66
+ # theme="soft",
67
+ # allow_flagging="never"
68
+ # )
69
 
70
+ # if __name__ == "__main__":
71
+ # iface.launch()