Nitin00043 committed on
Commit
11ef473
·
verified ·
1 Parent(s): 176916a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py CHANGED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
# Checkpoint selection: a publicly hosted, large Pix2Struct DocVQA model.
# Point this at another model id if required; a private checkpoint would
# additionally need authentication.
model_name = "google/pix2struct-docvqa-large"

# Both objects are created once at module import and reused afterwards.
model = Pix2StructForConditionalGeneration.from_pretrained(
    model_name,
)
processor = Pix2StructProcessor.from_pretrained(
    model_name,
)
13
+
14
def solve_problem(image):
    """Generate an answer for an uploaded problem image with Pix2Struct.

    Args:
        image: PIL image handed over by the Gradio ``Image`` input, or
            ``None`` when the user submitted without uploading anything.

    Returns:
        A two-line string ("Problem: <prompt>" followed by
        "Solution: <generated text>"), or a message starting with
        "Error processing image:" if anything goes wrong.
    """
    prompt = "Solve the following problem:"

    # Guard the empty-upload case explicitly so the user gets a readable
    # message instead of an AttributeError text.
    if image is None:
        return "Error processing image: no image was provided."

    try:
        # Ensure the image is in RGB.
        rgb_image = image.convert("RGB")

        # Preprocess image and text prompt.
        inputs = processor(
            images=[rgb_image],
            text=prompt,
            return_tensors="pt",
            max_patches=2048,
        )

        # Beam search without sampling: `temperature` would have no effect
        # here (it only applies when do_sample=True), so it is not passed.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )

        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    except Exception as e:
        # UI boundary: report the failure in the output box rather than
        # letting the exception propagate into Gradio.
        return f"Error processing image: {e}"

    # The processor output is not guaranteed to contain "input_ids" when
    # images are passed (for VQA checkpoints the prompt is rendered into the
    # image), so the known prompt string is reported directly instead of
    # being decoded back from the inputs.
    return f"Problem: {prompt}\nSolution: {solution}"
50
+
51
# Gradio UI: a single image upload as input and a text box as output.
_image_input = gr.Image(
    type="pil",
    label="Upload Your Problem Image",
    image_mode="RGB",
)
_solution_output = gr.Textbox(label="Solution", show_copy_button=True)

_description = (
    "Upload an image (for example, a handwritten math or logic problem) "
    "and get a solution generated by a high-capacity Pix2Struct model.\n\n"
    "Note: For best results on domain-specific tasks, consider fine-tuning on your own dataset."
)

# NOTE(review): these paths are relative; presumably the files sit next to
# app.py — confirm they are actually committed alongside it.
_example_inputs = [
    ["example_problem1.png"],
    ["example_problem2.jpg"],
]

iface = gr.Interface(
    fn=solve_problem,
    inputs=_image_input,
    outputs=_solution_output,
    title="Problem Solver with Pix2Struct",
    description=_description,
    examples=_example_inputs,
    theme="soft",
    allow_flagging="never",
)

# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()