import streamlit as st
from inference.preprocess_image import (
image_to_np_arr,
process_extracted_text,
post_process_gen_outputs
)
from inference.config import (
model_inf_inp_prompt,
header_pattern,
dots_pattern,
DEVICE,
model_name
)
from typing import List, Tuple, Optional, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
import easyocr
import time

# easyocr uses the GPU only when inference is not pinned to the CPU
use_gpu = DEVICE.type != 'cpu'

@st.cache_resource
def load_models(item_summarizer: str) -> Tuple:
    """
    Load the models required for inference. Cached so the models are not
    reloaded on every rerun of the Streamlit script.

    Parameters:
        item_summarizer: str, required -> Name of the LLM used for item summarization.

    Returns:
        Tuple -> (text_extractor, tokenizer, model) used by the inference pipeline.
    """
    # model to extract text from the image
    text_extractor = easyocr.Reader(['en'], gpu=use_gpu)

    # tokenizer and model to generate item summaries
    tokenizer = AutoTokenizer.from_pretrained(item_summarizer)
    model = AutoModelForCausalLM.from_pretrained(item_summarizer)

    return (text_extractor, tokenizer, model)

text_extractor, item_tokenizer, item_summarizer = load_models(item_summarizer=model_name)

# Pipeline entry point: extract and filter menu text from an uploaded image
async def extract_filter_img(image) -> Dict:
"""
1. Convert Image to numpy array
2. Detect & Extract Text from Image - List of Tuples
3. Process text , to filter out irrelevant text
4. Classify only menu-related strings from detected text
"""
progress_bar = st.progress(0)
status_message = st.empty()
functions_messages = [
(image_to_np_arr, 'Converting Image to required format', 'Done Converting !'),
(text_extractor.readtext, 'Extracting text from inp image', 'Done Extracting !'),
(process_extracted_text, 'Clean Raw Extracted text', 'Done Cleaning !'),
(classify_menu_text, 'Removing non-menu related text', 'Done removing !'),
]
# Initialize variables
result = image
total_steps = len(functions_messages)
ind_add_delays = [0, 2, 3, 4]
# Loop through each function and execute it with status update
for i, (func, start_message, end_message) in enumerate(functions_messages):
status_message.write(start_message)
if i in ind_add_delays:
time.sleep(0.5)
if i == 2: result = await func(result)
else: result = func(result)
status_message.write(end_message)
# Update the progress bar
progress_bar.progress((i + 1) / total_steps)
if i in ind_add_delays:
time.sleep(0.5)
progress_bar.empty()
status_message.empty()
return result
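
# A minimal sketch (an assumption, not code from this repo) of how the async
# pipeline above might be driven from a Streamlit upload widget. The widget
# label, the `run_pipeline` name, and passing raw bytes into extract_filter_img
# are all hypothetical; the exact input image_to_np_arr expects may differ.
def run_pipeline():
    import asyncio
    uploaded = st.file_uploader("Upload a menu image", type=["png", "jpg", "jpeg"])
    if uploaded is not None:
        # asyncio.run bridges Streamlit's synchronous script with the coroutine
        menu_text = asyncio.run(extract_filter_img(uploaded.read()))
        st.write(transcribe_menu_model(menu_text))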

def transcribe_menu_model(menu_text: List[str]) -> Dict:
    """
    Format the filtered menu text into the inference prompt, generate a summary
    with the causal LM, and post-process the raw generation into its final form.
    """
    prompt_item = model_inf_inp_prompt.format(menu_text)
    input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids

    outputs = item_summarizer.generate(
        input_ids,
        max_new_tokens=512,
        num_beams=4,
        pad_token_id=item_tokenizer.pad_token_id,
        eos_token_id=item_tokenizer.eos_token_id,
        bos_token_id=item_tokenizer.bos_token_id,
    )
    # Decode including special tokens; post_process_gen_outputs cleans the raw
    # generation using the header/dots patterns from the config
    prediction = item_tokenizer.batch_decode(outputs, skip_special_tokens=False)

    postpro_output = post_process_gen_outputs(prediction, header_pattern, dots_pattern)[0]
    return postpro_output
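
# Example (hypothetical inputs) of what a call might look like:
#   transcribe_menu_model(["Margherita Pizza 12.50", "Caesar Salad 9.00"])
# num_beams=4 trades generation speed for higher-likelihood beam-search output,
# and max_new_tokens=512 caps the length of the generated summary.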

def classify_menu_text(extrc_str: List[str]) -> List[str]:
    # Placeholder: currently a passthrough that returns the extracted strings
    # unchanged; menu/non-menu classification is not yet implemented
    return extrc_str
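
# One possible direction for classify_menu_text (a sketch under the assumption
# that bare prices and stray numbers are the main non-menu noise; not from this
# repo):
#
#   import re
#   return [s for s in extrc_str if not re.fullmatch(r'[\d\s$.,:-]*', s)]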