import streamlit as st
import torch
import pandas as pd
from PIL import Image
from pylatexenc.latex2text import LatexNodes2Text
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from qwen_vl_utils import process_vision_info

#############################
# Utility functions
#############################

def convert_latex_to_plain_text(latex_string):
    # Convert a LaTeX string to a plain-text math expression
    converter = LatexNodes2Text()
    plain_text = converter.latex_to_text(latex_string)
    return plain_text
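
# Illustrative only (not exact pylatexenc output):
#   convert_latex_to_plain_text(r"\frac{1}{2} + x = 3")  # -> roughly "1/2 + x = 3"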

#############################
# Caching model loads so they only happen once
#############################

@st.cache_resource(show_spinner=False)
def load_ocr_model():
    # Load the OCR model and its processor
    model_ocr = Qwen2VLForConditionalGeneration.from_pretrained(
        "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
        torch_dtype="auto",
        device_map="auto",
    )
    processor_ocr = AutoProcessor.from_pretrained("prithivMLmods/Qwen2-VL-OCR-2B-Instruct")
    return model_ocr, processor_ocr
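
# Note: device_map="auto" lets Accelerate place the weights on a GPU when one is
# available (falling back to CPU), and @st.cache_resource keeps the loaded models
# in memory across Streamlit reruns so they are only loaded once.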

@st.cache_resource(show_spinner=False)
def load_llm_model():
    # Load the math LLM and tokenizer with a BitsAndBytes 4-bit quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model_name = "deepseek-ai/deepseek-math-7b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
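
# NF4 ("4-bit NormalFloat") stores the weights in 4 bits, and double quantization
# compresses the quantization constants as well, so the 7B model's weights drop
# from roughly 14 GB in fp16 to around 4 GB; compute still runs in bfloat16.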

#############################
# OCR & Expression solver functions
#############################

def img_2_text(image, model_ocr, processor_ocr):
    # Prepare the conversation messages
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Derive the LaTeX expression from the given image."},
            ],
        }
    ]
    # Render the conversation into a text prompt via the chat template
    text = processor_ocr.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Process vision inputs
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor_ocr(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model_ocr.device)
    generated_ids = model_ocr.generate(**inputs, max_new_tokens=512)
    # Drop the echoed prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor_ocr.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # Defensive split in case the end-of-turn marker survives decoding
    return output_text[0].split('<|im_end|>')[0]
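
# Illustrative only: a photo of a quadratic might come back as the LaTeX string
# "x^{2} - 5x + 6 = 0", which the pylatexenc step then flattens into a plain-text
# form the solver prompt can use.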

def expression_solver(expression, model_llm, tokenizer_llm):
    device = next(model_llm.parameters()).device
    # Keep the prompt flush-left so the triple-quoted string carries no stray indentation
    prompt = f"""You are a helpful math assistant. Please analyze the problem carefully and provide a step-by-step solution.
- If the problem is an equation, solve for the unknown variable(s).
- If it is an expression, simplify it fully.
- If it is a word problem, explain how you arrive at the result.
- Output the final value in an <ANS> </ANS> tag with no other text inside it: either True or False for statements you are asked to verify, or the value(s) of the variable(s) for problems you are asked to solve.
Problem: {expression}
Answer:
"""
    inputs = tokenizer_llm(prompt, return_tensors="pt").to(device)
    outputs = model_llm.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )
    generated_text = tokenizer_llm.decode(outputs[0], skip_special_tokens=True)
    return generated_text
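
# Optional helper (a sketch; not wired into the UI below): the solver prompt asks
# the model to wrap its final answer in <ANS> </ANS>, so a small regex can pull
# that value out of the generated text, falling back to the full text when the
# tag is missing.
import re

def extract_final_answer(generated_text):
    match = re.search(r"<ANS>(.*?)</ANS>", generated_text, re.DOTALL)
    return match.group(1).strip() if match else generated_text.strip()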

def process_images(images, model_ocr, processor_ocr, model_llm, tokenizer_llm):
    results = []
    for image_file in images:
        # Rewind the upload buffer (st.image may already have read it), then open with PIL
        image_file.seek(0)
        image = Image.open(image_file)
        # Run OCR to get a LaTeX string
        ocr_text = img_2_text(image, model_ocr, processor_ocr)
        # Convert the LaTeX to a plain-text expression
        expression = convert_latex_to_plain_text(ocr_text)
        # Solve or simplify the expression using the LLM
        solution = expression_solver(expression, model_llm, tokenizer_llm)
        results.append({
            "Filename": image_file.name,
            "OCR LaTeX": ocr_text,
            "Converted Expression": expression,
            "Solution": solution,
        })
    return results

#############################
# Streamlit UI
#############################

st.title("Math OCR & Solver")
st.markdown(
    """
    This app uses a vision-language OCR model to extract a LaTeX expression from an image,
    converts it to plain text, and then uses a language model to solve or simplify the expression.
    """
)

st.sidebar.header("Upload Images")
uploaded_files = st.sidebar.file_uploader(
    "Choose one or more images", type=["png", "jpg", "jpeg"], accept_multiple_files=True
)

if uploaded_files:
    st.subheader("Uploaded Images")
    for file in uploaded_files:
        st.image(file, caption=file.name, use_column_width=True)
    if st.button("Process Images"):
        with st.spinner("Loading models and processing images..."):
            # Load models once (cached across reruns)
            model_ocr, processor_ocr = load_ocr_model()
            model_llm, tokenizer_llm = load_llm_model()
            # Process each uploaded image
            results = process_images(uploaded_files, model_ocr, processor_ocr, model_llm, tokenizer_llm)
        # Display results in a table
        df_results = pd.DataFrame(results)
        st.success("Processing complete!")
        st.write(df_results)
else:
    st.info("Please upload one or more images from the sidebar to begin.")