import streamlit as st
import os
import torch
import pandas as pd
from PIL import Image
from pylatexenc.latex2text import LatexNodes2Text
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Qwen2VLForConditionalGeneration,
    AutoProcessor
)
from qwen_vl_utils import process_vision_info
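
# Run this app with: `streamlit run <this_file>.py`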

#############################
# Utility functions
#############################

def convert_latex_to_plain_text(latex_string):
    converter = LatexNodes2Text()
    plain_text = converter.latex_to_text(latex_string)
    return plain_text
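
# Illustrative note: LatexNodes2Text renders LaTeX markup as readable text,
# e.g. r"\frac{3}{4} + \sqrt{2}" comes out roughly as "3/4 + √(2)" (the exact
# output depends on the installed pylatexenc version).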

#############################
# Caching model loads so they only happen once
#############################

@st.cache_resource(show_spinner=False)
def load_ocr_model():
    # Load OCR model and processor
    model_ocr = Qwen2VLForConditionalGeneration.from_pretrained(
        "prithivMLmods/Qwen2-VL-OCR-2B-Instruct", 
        torch_dtype="auto", 
        device_map="auto"
    )
    processor_ocr = AutoProcessor.from_pretrained("prithivMLmods/Qwen2-VL-OCR-2B-Instruct")
    return model_ocr, processor_ocr
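
# Note: device_map="auto" relies on the `accelerate` package; with a GPU
# available, the model weights are placed on it automatically.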

@st.cache_resource(show_spinner=False)
def load_llm_model():
    # Load LLM model and tokenizer with BitsAndBytes 4-bit quantization configuration
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                     # store weights in 4-bit
        bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
        bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
        bnb_4bit_compute_dtype=torch.bfloat16  # run compute in bfloat16
    )
    model_name = "deepseek-ai/deepseek-math-7b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto"
    )
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding (causal LM tokenizers often lack a pad token)
    return model, tokenizer
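
# Note: 4-bit loading via BitsAndBytesConfig requires the `bitsandbytes`
# package and, in most setups, a CUDA-capable GPU.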

#############################
# OCR & Expression solver functions
#############################

def img_2_text(image, model_ocr, processor_ocr):
    # Prepare the conversation messages
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Derive the latex expression from the image given"}
            ],
        }
    ]
    
    # Generate the text prompt from the conversation template
    text = processor_ocr.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Process vision inputs
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor_ocr(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model_ocr.device)
    
    generated_ids = model_ocr.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor_ocr.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # skip_special_tokens already drops <|im_end|>; the split is defensive
    return output_text[0].split('<|im_end|>')[0]

def expression_solver(expression, model_llm, tokenizer_llm):
    device = next(model_llm.parameters()).device
    prompt = f"""You are a helpful math assistant. Please analyze the problem carefully and provide a step-by-step solution. 
- If the problem is an equation, solve for the unknown variable(s). 
- If it is an expression, simplify it fully. 
- If it is a word problem, explain how you arrive at the result.
- Output final value, either True or False in case of expressions where you have to verify, or the value of variables in expressions where you have to solve in a <ANS> </ANS> tag with no other text in it.

Problem: {expression}
Answer:
"""
    inputs = tokenizer_llm(prompt, return_tensors="pt").to(device)
    outputs = model_llm.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        temperature=0.7
    )
    # Decode only the newly generated tokens so the prompt is not echoed back
    generated_text = tokenizer_llm.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return generated_text
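
# Optional helper (a minimal sketch, not part of the original pipeline): the
# prompt above asks the model to wrap its final answer in an <ANS> </ANS> tag,
# but nothing below parses it. `extract_final_answer` is a hypothetical name.
import re

def extract_final_answer(solution_text):
    # Return the contents of the first <ANS>...</ANS> tag, or None if the
    # model did not emit one.
    match = re.search(r"<ANS>\s*(.*?)\s*</ANS>", solution_text, re.DOTALL)
    return match.group(1) if match else None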

def process_images(images, model_ocr, processor_ocr, model_llm, tokenizer_llm):
    results = []
    for image_file in images:
        # Open image with PIL
        image = Image.open(image_file)
        # Run OCR to get LaTeX string
        ocr_text = img_2_text(image, model_ocr, processor_ocr)
        # Convert LaTeX to plain text expression
        expression = convert_latex_to_plain_text(ocr_text)
        # Solve or simplify the expression using the LLM
        solution = expression_solver(expression, model_llm, tokenizer_llm)
        results.append({
            "Filename": image_file.name,
            "OCR LaTeX": ocr_text,
            "Converted Expression": expression,
            "Solution": solution
        })
    return results

#############################
# Streamlit UI
#############################

st.title("Math OCR & Solver")
st.markdown(
    """
    This app uses a Vision-Language OCR model to extract a LaTeX expression from an image,
    converts it to plain text, and then uses a language model to solve or simplify the expression.
    """
)

st.sidebar.header("Upload Images")
uploaded_files = st.sidebar.file_uploader("Choose one or more images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)

if uploaded_files:
    st.subheader("Uploaded Images")
    for file in uploaded_files:
        st.image(file, caption=file.name, use_container_width=True)
    
    if st.button("Process Images"):
        with st.spinner("Loading models and processing images..."):
            # Load models once
            model_ocr, processor_ocr = load_ocr_model()
            model_llm, tokenizer_llm = load_llm_model()
            
            # Process each uploaded image
            results = process_images(uploaded_files, model_ocr, processor_ocr, model_llm, tokenizer_llm)
            # Display results in a table
            df_results = pd.DataFrame(results)
            st.success("Processing complete!")
            st.write(df_results)
else:
    st.info("Please upload one or more images from the sidebar to begin.")