Spaces:
Sleeping
Sleeping
File size: 6,127 Bytes
44d6f7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import streamlit as st
import os
import torch
import pandas as pd
from PIL import Image
from pylatexenc.latex2text import LatexNodes2Text
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
Qwen2VLForConditionalGeneration,
AutoProcessor
)
from qwen_vl_utils import process_vision_info
#############################
# Utility functions
#############################
def convert_latex_to_plain_text(latex_string):
    """Render a LaTeX snippet as plain text.

    Args:
        latex_string: LaTeX source to convert.

    Returns:
        Whatever ``LatexNodes2Text.latex_to_text`` produces for
        ``latex_string``.
    """
    # A fresh converter is created per call, as in the original helper.
    return LatexNodes2Text().latex_to_text(latex_string)
#############################
# Caching model loads so they only happen once
#############################
@st.cache_resource(show_spinner=False)
def load_ocr_model():
    """Load the vision-language OCR model together with its processor.

    The function is wrapped in ``st.cache_resource`` so the load is meant to
    happen only once rather than on every Streamlit rerun.

    Returns:
        A ``(model, processor)`` tuple for the
        ``prithivMLmods/Qwen2-VL-OCR-2B-Instruct`` checkpoint.
    """
    checkpoint = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
    # dtype and device placement are both left to the library ("auto").
    ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype="auto",
        device_map="auto",
    )
    ocr_processor = AutoProcessor.from_pretrained(checkpoint)
    return ocr_model, ocr_processor
@st.cache_resource(show_spinner=False)
def load_llm_model():
    """Load the math LLM in 4-bit quantized form, plus its tokenizer.

    The function is wrapped in ``st.cache_resource`` so the load is meant to
    happen only once rather than on every Streamlit rerun.

    Returns:
        A ``(model, tokenizer)`` tuple for
        ``deepseek-ai/deepseek-math-7b-instruct``. The tokenizer's pad token
        is set to its end-of-sequence token.
    """
    checkpoint = "deepseek-ai/deepseek-math-7b-instruct"

    # 4-bit NF4 weights with double quantization; compute runs in bfloat16.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    llm_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Reuse EOS as the padding token.
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

    llm = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        quantization_config=quantization,
        device_map="auto",
    )
    return llm, llm_tokenizer
#############################
# OCR & Expression solver functions
#############################
def img_2_text(image, model_ocr, processor_ocr):
    """Ask the OCR model for the LaTeX expression shown in an image.

    Args:
        image: Image object placed in the chat message's ``"image"`` entry.
        model_ocr: Vision-language model used for generation.
        processor_ocr: Processor that builds the chat prompt, prepares the
            model inputs and decodes the generated ids.

    Returns:
        The decoded text of the first (only) sample, cut at the first
        ``<|im_end|>`` marker if one is present.
    """
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Derive the latex expression from the image given"},
            ],
        }
    ]

    # Render the conversation to a prompt string (not token ids).
    chat_prompt = processor_ocr.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(conversation)

    batch = processor_ocr(
        text=[chat_prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    batch = batch.to(model_ocr.device)

    generated = model_ocr.generate(**batch, max_new_tokens=512)

    # Drop the leading prompt-length prefix from each generated sequence so
    # only the newly produced ids are decoded.
    continuations = []
    for prompt_ids, full_ids in zip(batch.input_ids, generated):
        continuations.append(full_ids[len(prompt_ids):])

    decoded = processor_ocr.batch_decode(
        continuations,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    latex, _, _ = decoded[0].partition('<|im_end|>')
    return latex
def expression_solver(expression, model_llm, tokenizer_llm):
    """Ask the LLM for a step-by-step solution of a math expression.

    Args:
        expression: Plain-text problem inserted into the prompt.
        model_llm: Causal language model used for sampling.
        tokenizer_llm: Tokenizer matching ``model_llm``.

    Returns:
        The model's answer as a string, with surrounding whitespace removed.
        Only the newly generated tokens are decoded; the instruction prompt
        is not echoed back (mirroring the trimming done in ``img_2_text``).
    """
    device = next(model_llm.parameters()).device
    prompt = f"""You are a helpful math assistant. Please analyze the problem carefully and provide a step-by-step solution.
- If the problem is an equation, solve for the unknown variable(s).
- If it is an expression, simplify it fully.
- If it is a word problem, explain how you arrive at the result.
- Output final value, either True or False in case of expressions where you have to verify, or the value of variables in expressions where you have to solve in a <ANS> </ANS> tag with no other text in it.
Problem: {expression}
Answer:
"""
    inputs = tokenizer_llm(prompt, return_tensors="pt").to(device)
    outputs = model_llm.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )
    # The generated sequence begins with the prompt tokens; slice them off so
    # the caller receives just the answer instead of prompt + answer.
    prompt_length = inputs["input_ids"].shape[-1]
    answer_ids = outputs[0][prompt_length:]
    return tokenizer_llm.decode(answer_ids, skip_special_tokens=True).strip()
def process_images(images, model_ocr, processor_ocr, model_llm, tokenizer_llm):
    """Run OCR, LaTeX-to-text conversion and solving for each image file.

    Args:
        images: Iterable of file-like objects that ``PIL.Image.open`` accepts
            and that expose a ``name`` attribute.
        model_ocr: OCR model handed to ``img_2_text``.
        processor_ocr: OCR processor handed to ``img_2_text``.
        model_llm: Language model handed to ``expression_solver``.
        tokenizer_llm: Tokenizer handed to ``expression_solver``.

    Returns:
        A list with one dict per input, in input order, holding the keys
        ``"Filename"``, ``"OCR LaTeX"``, ``"Converted Expression"`` and
        ``"Solution"``.
    """

    def _solve_one(upload):
        # image -> LaTeX -> plain text -> LLM answer
        latex = img_2_text(Image.open(upload), model_ocr, processor_ocr)
        plain = convert_latex_to_plain_text(latex)
        answer = expression_solver(plain, model_llm, tokenizer_llm)
        return {
            "Filename": upload.name,
            "OCR LaTeX": latex,
            "Converted Expression": plain,
            "Solution": answer,
        }

    return [_solve_one(upload) for upload in images]
#############################
# Streamlit UI
#############################
# Page header and short description of the pipeline.
st.title("Math OCR & Solver")
st.markdown(
    """
This app uses a Vision-Language OCR model to extract a LaTeX expression from an image,
converts it to plain text, and then uses a language model to solve or simplify the expression.
"""
)

# Sidebar: multi-file image uploader.
st.sidebar.header("Upload Images")
uploaded_files = st.sidebar.file_uploader(
    "Choose one or more images",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)

if not uploaded_files:
    # Nothing uploaded yet: show the hint and do nothing else.
    st.info("Please upload one or more images from the sidebar to begin.")
else:
    st.subheader("Uploaded Images")
    # NOTE(review): recent Streamlit releases appear to deprecate
    # `use_column_width` in favour of `use_container_width` — confirm against
    # the pinned Streamlit version before changing it.
    for uploaded_file in uploaded_files:
        st.image(uploaded_file, caption=uploaded_file.name, use_column_width=True)

    if st.button("Process Images"):
        with st.spinner("Loading models and processing images..."):
            # Both loaders are wrapped in st.cache_resource.
            model_ocr, processor_ocr = load_ocr_model()
            model_llm, tokenizer_llm = load_llm_model()
            results = process_images(
                uploaded_files,
                model_ocr,
                processor_ocr,
                model_llm,
                tokenizer_llm,
            )
        # One table row per processed image.
        df_results = pd.DataFrame(results)
        st.success("Processing complete!")
        st.write(df_results)
|