majorSeaweed committed
Commit 44d6f7f · verified · 1 Parent(s): bdddd8e

Create app.py

Files changed (1)
app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
+ import streamlit as st
+ import torch
+ import pandas as pd
+ from PIL import Image
+ from pylatexenc.latex2text import LatexNodes2Text
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     Qwen2VLForConditionalGeneration,
+     AutoProcessor
+ )
+ from qwen_vl_utils import process_vision_info
+ 
+ #############################
+ # Utility functions
+ #############################
+ 
+ def convert_latex_to_plain_text(latex_string):
+     converter = LatexNodes2Text()
+     plain_text = converter.latex_to_text(latex_string)
+     return plain_text
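+ # For example (exact spacing/output may vary across pylatexenc versions):
+ #   convert_latex_to_plain_text(r"\frac{3}{4} + x^{2} = 7")  ->  "3/4 + x^2 = 7"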
+ 
+ #############################
+ # Caching model loads so they only happen once
+ #############################
+ 
+ @st.cache_resource(show_spinner=False)
+ def load_ocr_model():
+     # Load OCR model and processor
+     model_ocr = Qwen2VLForConditionalGeneration.from_pretrained(
+         "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
+         torch_dtype="auto",
+         device_map="auto"
+     )
+     processor_ocr = AutoProcessor.from_pretrained("prithivMLmods/Qwen2-VL-OCR-2B-Instruct")
+     return model_ocr, processor_ocr
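+ # Note: @st.cache_resource keeps the returned objects alive across Streamlit
+ # reruns, so each model is loaded into memory once per server process rather
+ # than on every user interaction.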
+ 
+ @st.cache_resource(show_spinner=False)
+ def load_llm_model():
+     # Load LLM model and tokenizer with BitsAndBytes 4-bit quantization configuration
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+     model_name = "deepseek-ai/deepseek-math-7b-instruct"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         quantization_config=bnb_config,
+         device_map="auto"
+     )
+     # The tokenizer has no dedicated pad token, so reuse EOS for padding
+     tokenizer.pad_token = tokenizer.eos_token
+     return model, tokenizer
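+ # NF4 4-bit weights with double quantization shrink the 7B model to roughly
+ # 4-5 GB of GPU memory (the exact figure depends on hardware and library
+ # versions), while matmuls still run in bfloat16 via bnb_4bit_compute_dtype.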
+ 
+ #############################
+ # OCR & Expression solver functions
+ #############################
+ 
+ def img_2_text(image, model_ocr, processor_ocr):
+     # Prepare the conversation messages
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": "Derive the LaTeX expression from the given image"}
+             ],
+         }
+     ]
+ 
+     # Generate the text prompt from the conversation template
+     text = processor_ocr.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     # Process vision inputs
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor_ocr(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to(model_ocr.device)
+ 
+     generated_ids = model_ocr.generate(**inputs, max_new_tokens=512)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor_ocr.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )
+     return output_text[0].split('<|im_end|>')[0]
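+ # Slicing len(in_ids) tokens off each output drops the echoed prompt, so only
+ # newly generated text is decoded; the '<|im_end|>' split is a belt-and-braces
+ # guard in case the stop marker survives decoding.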
+ 
+ def expression_solver(expression, model_llm, tokenizer_llm):
+     device = next(model_llm.parameters()).device
+     prompt = f"""You are a helpful math assistant. Please analyze the problem carefully and provide a step-by-step solution.
+ - If the problem is an equation, solve for the unknown variable(s).
+ - If it is an expression, simplify it fully.
+ - If it is a word problem, explain how you arrive at the result.
+ - Output the final value inside a single <ANS> </ANS> tag with no other text in it: either True or False for statements you must verify, or the value(s) of the variable(s) for problems you must solve.
+ 
+ Problem: {expression}
+ Answer:
+ """
+     inputs = tokenizer_llm(prompt, return_tensors="pt").to(device)
+     outputs = model_llm.generate(
+         **inputs,
+         max_new_tokens=512,
+         do_sample=True,
+         top_p=0.95,
+         temperature=0.7
+     )
+     generated_text = tokenizer_llm.decode(outputs[0], skip_special_tokens=True)
+     return generated_text
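+ # The prompt asks the model to wrap its verdict in <ANS></ANS>; a small regex
+ # (illustrative sketch, not wired into the UI below) could recover just that part:
+ #
+ #     import re
+ #     m = re.search(r"<ANS>\s*(.*?)\s*</ANS>", generated_text, re.DOTALL)
+ #     final_answer = m.group(1) if m else generated_text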
+ 
+ def process_images(images, model_ocr, processor_ocr, model_llm, tokenizer_llm):
+     results = []
+     for image_file in images:
+         # Open image with PIL
+         image = Image.open(image_file)
+         # Run OCR to get LaTeX string
+         ocr_text = img_2_text(image, model_ocr, processor_ocr)
+         # Convert LaTeX to plain text expression
+         expression = convert_latex_to_plain_text(ocr_text)
+         # Solve or simplify the expression using the LLM
+         solution = expression_solver(expression, model_llm, tokenizer_llm)
+         results.append({
+             "Filename": image_file.name,
+             "OCR LaTeX": ocr_text,
+             "Converted Expression": expression,
+             "Solution": solution
+         })
+     return results
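+ # Each uploaded file becomes one result row: the raw OCR LaTeX, the plain-text
+ # form actually fed to the solver, and the solver's full step-by-step answer.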
+ 
+ #############################
+ # Streamlit UI
+ #############################
+ 
+ st.title("Math OCR & Solver")
+ st.markdown(
+     """
+     This app uses a Vision-Language OCR model to extract a LaTeX expression from an image,
+     converts it to plain text, and then uses a language model to solve or simplify the expression.
+     """
+ )
+ 
+ st.sidebar.header("Upload Images")
+ uploaded_files = st.sidebar.file_uploader(
+     "Choose one or more images", type=["png", "jpg", "jpeg"], accept_multiple_files=True
+ )
+ 
+ if uploaded_files:
+     st.subheader("Uploaded Images")
+     for file in uploaded_files:
+         st.image(file, caption=file.name, use_column_width=True)
+ 
+     if st.button("Process Images"):
+         with st.spinner("Loading models and processing images..."):
+             # Load models once
+             model_ocr, processor_ocr = load_ocr_model()
+             model_llm, tokenizer_llm = load_llm_model()
+ 
+             # Process each uploaded image
+             results = process_images(uploaded_files, model_ocr, processor_ocr, model_llm, tokenizer_llm)
+             # Display results in a table
+             df_results = pd.DataFrame(results)
+             st.success("Processing complete!")
+             st.write(df_results)
+ else:
+     st.info("Please upload one or more images from the sidebar to begin.")
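
To try the app locally, streamlit run app.py should suffice once the imports above resolve; besides the obvious packages (streamlit, torch, transformers, pandas, pillow, pylatexenc), the 4-bit load path will also need bitsandbytes (and typically accelerate for device_map="auto"), and the OCR helper needs qwen-vl-utils.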