import streamlit as st
from inference.preprocess_image import (
    image_to_np_arr,
    process_extracted_text,
    post_process_gen_outputs,
)
from inference.config import (
    model_inf_inp_prompt,
    header_pattern,
    dots_pattern,
    DEVICE,
    model_name,
)
from typing import List, Tuple, Optional, AnyStr, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
import easyocr
import time

# easyocr can use the GPU only when torch reports a non-CPU device.
use_gpu = DEVICE.type != 'cpu'


@st.cache_resource
def load_models(item_summarizer: str) -> Tuple:
    """
    Load every model required for the inference process.

    Cached with ``st.cache_resource`` so the expensive model loads happen
    only once per Streamlit session instead of on every script rerun.

    Parameters:
        item_summarizer: str, required ->
            The LLM model name to be used for item summarization.

    Returns:
        Tuple -> (easyocr.Reader, tokenizer, causal-LM model)
    """
    # model to extract text from image
    text_extractor = easyocr.Reader(['en'], gpu=use_gpu)

    # tokenizer and model to generate item summary
    tokenizer = AutoTokenizer.from_pretrained(item_summarizer)
    model = AutoModelForCausalLM.from_pretrained(item_summarizer)

    return (text_extractor, tokenizer, model)


# Eagerly load the models once at import time; load_models is cached.
text_extractor, item_tokenizer, item_summarizer = load_models(
    item_summarizer=model_name
)


async def extract_filter_img(image) -> List[str]:
    """
    Run the image -> menu-text pipeline, reporting progress in the UI.

    1. Convert Image to numpy array
    2. Detect & Extract Text from Image - List of Tuples
    3. Process text, to filter out irrelevant text (async step)
    4. Classify only menu-related strings from detected text

    Parameters:
        image -> The uploaded image (anything ``image_to_np_arr`` accepts).

    Returns:
        List[str] -> Menu-related text strings extracted from the image.
        (The original annotation said Dict, but the last pipeline step,
        ``classify_menu_text``, returns a list.)
    """
    progress_bar = st.progress(0)
    status_message = st.empty()

    # (callable, message shown while running, message shown when done)
    functions_messages = [
        (image_to_np_arr, 'Converting Image to required format', 'Done Converting !'),
        (text_extractor.readtext, 'Extracting text from inp image', 'Done Extracting !'),
        (process_extracted_text, 'Clean Raw Extracted text', 'Done Cleaning !'),
        (classify_menu_text, 'Removing non-menu related text', 'Done removing !'),
    ]

    # Initialize variables
    result = image
    total_steps = len(functions_messages)
    # Steps that get a short artificial pause so their status message is
    # readable. (Was [0, 2, 3, 4]; index 4 was unreachable with 4 steps.)
    ind_add_delays = {0, 2, 3}

    # Loop through each function and execute it with a status update.
    for i, (func, start_message, end_message) in enumerate(functions_messages):
        status_message.write(start_message)
        if i in ind_add_delays:
            time.sleep(0.5)

        # process_extracted_text (step index 2) is a coroutine; await it.
        if i == 2:
            result = await func(result)
        else:
            result = func(result)
        status_message.write(end_message)

        # Update the progress bar.
        progress_bar.progress((i + 1) / total_steps)
        if i in ind_add_delays:
            time.sleep(0.5)

    progress_bar.empty()
    status_message.empty()
    return result


def transcribe_menu_model(menu_text: List[str]) -> str:
    """
    Summarize the extracted menu text with the causal language model.

    Parameters:
        menu_text: List[str], required ->
            Cleaned, menu-related text strings from the pipeline.

    Returns:
        str -> First post-processed generation for the menu items.
        (The original annotation said Dict, but the function returns
        element [0] of the post-processed decoded outputs.)
    """
    prompt_item = model_inf_inp_prompt.format(menu_text)
    input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids

    outputs = item_summarizer.generate(
        input_ids,
        max_new_tokens=512,
        num_beams=4,
        pad_token_id=item_tokenizer.pad_token_id,
        eos_token_id=item_tokenizer.eos_token_id,
        bos_token_id=item_tokenizer.bos_token_id,
    )
    # Special tokens are deliberately kept here; the regex post-processing
    # below is responsible for stripping them.
    prediction = item_tokenizer.batch_decode(outputs, skip_special_tokens=False)

    postpro_output = post_process_gen_outputs(
        prediction, header_pattern, dots_pattern
    )[0]
    return postpro_output


def classify_menu_text(extrc_str: List[str]) -> List[str]:
    """
    Filter the extracted strings down to menu-related ones.

    Currently a pass-through placeholder: returns the input unchanged.
    """
    return extrc_str