import streamlit as st

from inference.preprocess_image import (
    image_to_np_arr,
    process_extracted_text,
    post_process_gen_outputs
)

from inference.config import (
    model_inf_inp_prompt,
    header_pattern,
    dots_pattern,
    DEVICE,
    model_name
)
from typing import List, Tuple, AnyStr, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
import easyocr
import time

# EasyOCR's Reader takes a plain boolean GPU flag, so derive it once from the
# configured torch device.
use_gpu = DEVICE.type != 'cpu'


@st.cache_resource
def load_models(item_summarizer: AnyStr) -> Tuple:
    """
    Load the models required for the inference process. Cached so the models
    are loaded only once rather than on every call.

    Parameters:
        item_summarizer: str, required -> Name of the LLM to use for item
            summarization.

    Returns:
        Tuple -> (text_extractor, tokenizer, model), the models required for
            the inference process.
    """
    # OCR reader used to detect and extract English text from the image.
    text_extractor = easyocr.Reader(['en'], gpu=use_gpu)

    # Tokenizer and causal LM used to summarize the extracted menu items.
    # Moving the model to DEVICE keeps generation consistent with use_gpu.
    tokenizer = AutoTokenizer.from_pretrained(item_summarizer)
    model = AutoModelForCausalLM.from_pretrained(item_summarizer).to(DEVICE)

    return (text_extractor, tokenizer, model)


text_extractor, item_tokenizer, item_summarizer = load_models(item_summarizer=model_name)
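# Note: st.cache_resource keys the cache on the function's arguments, so
# calling load_models with a different item_summarizer name would load and
# cache a second set of models rather than reusing this one.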


async def extract_filter_img(image) -> List[AnyStr]:
    """
    1. Convert the image to a numpy array
    2. Detect and extract text from the image (list of tuples)
    3. Process the text to filter out irrelevant strings
    4. Keep only the menu-related strings from the detected text
    """
    progress_bar = st.progress(0)
    status_message = st.empty()

    # Each pipeline step is paired with the status messages shown while it
    # runs and after it completes.
    functions_messages = [
        (image_to_np_arr, 'Converting Image to required format', 'Done Converting !'),
        (text_extractor.readtext, 'Extracting text from input image', 'Done Extracting !'),
        (process_extracted_text, 'Cleaning Raw Extracted text', 'Done Cleaning !'),
        (classify_menu_text, 'Removing non-menu related text', 'Done removing !'),
    ]

    result = image
    total_steps = len(functions_messages)
    # Steps that get a short artificial pause so the status text stays
    # readable (step 1, the OCR pass, is slow enough on its own).
    ind_add_delays = [0, 2, 3]

    for i, (func, start_message, end_message) in enumerate(functions_messages):
        status_message.write(start_message)

        if i in ind_add_delays:
            time.sleep(0.5)

        # process_extracted_text is a coroutine, so step 2 is awaited.
        if i == 2:
            result = await func(result)
        else:
            result = func(result)

        status_message.write(end_message)
        progress_bar.progress((i + 1) / total_steps)

        if i in ind_add_delays:
            time.sleep(0.5)

    progress_bar.empty()
    status_message.empty()
    return result
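
# Hedged usage sketch (illustrative only, not the shipped entry point):
# Streamlit scripts run synchronously, so the coroutine above has to be
# driven by an event loop. The uploader widget and the assumption that
# image_to_np_arr accepts raw uploaded bytes are illustrative, not taken
# from this codebase.
#
#     import asyncio
#
#     uploaded = st.file_uploader("Upload a menu image", type=["png", "jpg"])
#     if uploaded is not None:
#         menu_lines = asyncio.run(extract_filter_img(uploaded.read()))
#         st.write(menu_lines)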


def transcribe_menu_model(menu_text: List[AnyStr]) -> Dict:
    """
    Summarize the filtered menu strings with the causal LM and return the
    post-processed generation.
    """
    prompt_item = model_inf_inp_prompt.format(menu_text)
    # Move the inputs to the same device as the model loaded above.
    input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids.to(DEVICE)

    outputs = item_summarizer.generate(
        input_ids,
        max_new_tokens=512,
        num_beams=4,
        pad_token_id=item_tokenizer.pad_token_id,
        eos_token_id=item_tokenizer.eos_token_id,
        bos_token_id=item_tokenizer.bos_token_id
    )

    # Special tokens are kept at decode time; they are presumably stripped by
    # post_process_gen_outputs via header_pattern / dots_pattern.
    prediction = item_tokenizer.batch_decode(outputs, skip_special_tokens=False)

    return post_process_gen_outputs(prediction, header_pattern, dots_pattern)[0]
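
# Hedged example (illustrative input only; the real prompt template is
# model_inf_inp_prompt from inference.config):
#
#     transcribe_menu_model(["Margherita Pizza 9.50", "Caesar Salad 7.00"])
#
# returns the first post-processed generation; its exact shape depends on
# post_process_gen_outputs.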


def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
    # Placeholder: currently a passthrough that keeps every extracted string.
    return extrc_str
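

# Hedged sketch of one possible classifier (NOT the author's method; the
# passthrough above is the shipped behaviour). A cheap heuristic keeps only
# strings containing a price-like token:
def _classify_menu_text_heuristic(extrc_str: List[AnyStr]) -> List[AnyStr]:
    import re

    # Match prices such as "9.50" or "12,00"; drop headers, addresses, etc.
    price_pattern = re.compile(r'\d+[.,]\d{2}')
    return [s for s in extrc_str if price_pattern.search(s)]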