import streamlit as st
import os
import torch
import pandas as pd
from PIL import Image
from pylatexenc.latex2text import LatexNodes2Text
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from qwen_vl_utils import process_vision_info

#############################
# Utility functions
#############################


def convert_latex_to_plain_text(latex_string):
    """Convert a LaTeX string to plain text using pylatexenc's LatexNodes2Text.

    Args:
        latex_string: LaTeX source, e.g. the OCR model's output.

    Returns:
        The plain-text rendering produced by ``LatexNodes2Text.latex_to_text``.
    """
    converter = LatexNodes2Text()
    plain_text = converter.latex_to_text(latex_string)
    return plain_text


#############################
# Caching model loads so they only happen once
#############################


@st.cache_resource(show_spinner=False)
def load_ocr_model():
    """Load the Qwen2-VL OCR model and its processor.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of reloading the weights.

    Returns:
        A ``(model_ocr, processor_ocr)`` tuple.
    """
    # Load OCR model and processor; dtype and device placement are left to
    # transformers ("auto").
    model_ocr = Qwen2VLForConditionalGeneration.from_pretrained(
        "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
        torch_dtype="auto",
        device_map="auto",
    )
    processor_ocr = AutoProcessor.from_pretrained(
        "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
    )
    return model_ocr, processor_ocr


@st.cache_resource(show_spinner=False)
def load_llm_model():
    """Load the DeepSeek-Math LLM (4-bit quantized) and its tokenizer.

    The model is loaded with a BitsAndBytes NF4 config using double
    quantization and bfloat16 compute dtype. Cached via ``st.cache_resource``.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer's pad token is set to
        its EOS token.
    """
    # Load LLM model and tokenizer with BitsAndBytes 4-bit quantization configuration
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model_name = "deepseek-ai/deepseek-math-7b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


#############################
# OCR & Expression solver functions
#############################


def img_2_text(image, model_ocr, processor_ocr):
    """Run the vision-language OCR model on *image* and return its LaTeX text.

    Args:
        image: A PIL image.
        model_ocr: Model returned by ``load_ocr_model``.
        processor_ocr: Processor returned by ``load_ocr_model``.

    Returns:
        The decoded model output (up to 512 new tokens), truncated at the
        first ``<|im_end|>`` marker if one is present.
    """
    # Prepare the conversation messages: one user turn with the image and
    # the instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Derive the latex expression from the image given"},
            ],
        }
    ]

    # Generate the text prompt from the conversation template
    text = processor_ocr.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Process vision inputs
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor_ocr(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model_ocr.device)

    generated_ids = model_ocr.generate(**inputs, max_new_tokens=512)
    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor_ocr.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0].split('<|im_end|>')[0]


def expression_solver(expression, model_llm, tokenizer_llm):
    """Ask the LLM to solve or simplify *expression*.

    Args:
        expression: Plain-text math problem (e.g. from
            ``convert_latex_to_plain_text``).
        model_llm: Causal LM returned by ``load_llm_model``.
        tokenizer_llm: Tokenizer returned by ``load_llm_model``.

    Returns:
        Only the text the model generated after the prompt. Output is sampled
        (top-p 0.95, temperature 0.7), so it is not deterministic.
    """
    device = next(model_llm.parameters()).device
    # NOTE(review): the last bullet asks for the answer "in a tag" but never
    # names the tag; the tag name may have been lost - confirm intended wording.
    prompt = f"""You are a helpful math assistant. Please analyze the problem carefully and provide a step-by-step solution.
- If the problem is an equation, solve for the unknown variable(s).
- If it is an expression, simplify it fully.
- If it is a word problem, explain how you arrive at the result.
- Output final value, either True or False in case of expressions where you have to verify, or the value of variables in expressions where you have to solve in a tag with no other text in it.

Problem: {expression}

Answer:
"""

    inputs = tokenizer_llm(prompt, return_tensors="pt").to(device)
    outputs = model_llm.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )
    # For a decoder-only model, generate() returns the prompt tokens followed
    # by the continuation. Drop the prompt so the result holds only the
    # model's answer, matching the trimming done in img_2_text.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_text = tokenizer_llm.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    return generated_text


def process_images(images, model_ocr, processor_ocr, model_llm, tokenizer_llm):
    """Run the OCR -> LaTeX-to-text -> LLM pipeline over each uploaded image.

    Args:
        images: Iterable of uploaded file objects; each must be openable by
            ``PIL.Image.open`` and expose a ``.name`` attribute.
        model_ocr, processor_ocr: Outputs of ``load_ocr_model``.
        model_llm, tokenizer_llm: Outputs of ``load_llm_model``.

    Returns:
        A list with one dict per image, keyed by "Filename", "OCR LaTeX",
        "Converted Expression" and "Solution".
    """
    results = []
    for image_file in images:
        # Open image with PIL
        image = Image.open(image_file)

        # Run OCR to get LaTeX string
        ocr_text = img_2_text(image, model_ocr, processor_ocr)

        # Convert LaTeX to plain text expression
        expression = convert_latex_to_plain_text(ocr_text)

        # Solve or simplify the expression using the LLM
        solution = expression_solver(expression, model_llm, tokenizer_llm)

        results.append({
            "Filename": image_file.name,
            "OCR LaTeX": ocr_text,
            "Converted Expression": expression,
            "Solution": solution,
        })
    return results


#############################
# Streamlit UI
#############################

st.title("Math OCR & Solver")
st.markdown(
    """
This app uses a Vision-Language OCR model to extract a LaTeX expression from an image,
converts it to plain text, and then uses a language model to solve or simplify the expression.
"""
)

st.sidebar.header("Upload Images")
uploaded_files = st.sidebar.file_uploader(
    "Choose one or more images",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)

if uploaded_files:
    st.subheader("Uploaded Images")
    for file in uploaded_files:
        st.image(file, caption=file.name, use_column_width=True)

    if st.button("Process Images"):
        with st.spinner("Loading models and processing images..."):
            # Load models once (cached across reruns by st.cache_resource)
            model_ocr, processor_ocr = load_ocr_model()
            model_llm, tokenizer_llm = load_llm_model()

            # Process each uploaded image
            results = process_images(
                uploaded_files, model_ocr, processor_ocr, model_llm, tokenizer_llm
            )

        # Display results in a table
        df_results = pd.DataFrame(results)
        st.success("Processing complete!")
        st.write(df_results)
else:
    st.info("Please upload one or more images from the sidebar to begin.")