import streamlit as st
from inference.preprocess_image import (
    image_to_np_arr,
    process_extracted_text,
    post_process_gen_outputs
)
from inference.config import (
    model_inf_inp_prompt,
    header_pattern,
    dots_pattern,
    DEVICE,
    model_name
)
from typing import List, Tuple, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
import easyocr
import time

# easyocr should only request the GPU when one is actually available
use_gpu = DEVICE.type != 'cpu'
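
# For context (an assumption about inference.config, which is not shown
# here): the DEVICE.type check above implies DEVICE is a torch.device, e.g.
#   DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")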

@st.cache_resource
def load_models(item_summarizer: str) -> Tuple:
    """
    Load the models required for the inference process. Cached so the models
    are loaded once instead of on every call.

    Parameters:
        item_summarizer: str, required -> Name of the LLM used for item summarization.

    Returns:
        Tuple -> (text extractor, tokenizer, summarization model).
    """
    # OCR model to extract text from the image
    text_extractor = easyocr.Reader(['en'], gpu=use_gpu)
    # Tokenizer and model to generate the item summary; the model is moved to
    # the configured device to match the DEVICE check above
    tokenizer = AutoTokenizer.from_pretrained(item_summarizer)
    model = AutoModelForCausalLM.from_pretrained(item_summarizer).to(DEVICE)
    return (text_extractor, tokenizer, model)
text_extractor, item_tokenizer, item_summarizer = load_models(item_summarizer=model_name)

# Pipeline: image -> numpy array -> OCR -> text cleaning -> menu classification
async def extract_filter_img(image) -> List[str]:
    """
    1. Convert the image to a numpy array
    2. Detect & extract text from the image - list of tuples
    3. Process the text to filter out irrelevant strings
    4. Keep only menu-related strings from the detected text
    """
    progress_bar = st.progress(0)
    status_message = st.empty()
    functions_messages = [
        (image_to_np_arr, 'Converting image to required format', 'Done converting!'),
        (text_extractor.readtext, 'Extracting text from input image', 'Done extracting!'),
        (process_extracted_text, 'Cleaning raw extracted text', 'Done cleaning!'),
        (classify_menu_text, 'Removing non-menu related text', 'Done removing!'),
    ]

    # Initialize variables
    result = image
    total_steps = len(functions_messages)
    # Steps around which to pause briefly so the status text stays readable
    ind_add_delays = [0, 2, 3]

    # Run each step in turn, feeding each output into the next step
    for i, (func, start_message, end_message) in enumerate(functions_messages):
        status_message.write(start_message)
        if i in ind_add_delays:
            time.sleep(0.5)
        # process_extracted_text (step index 2) is a coroutine and must be awaited
        if i == 2:
            result = await func(result)
        else:
            result = func(result)
        status_message.write(end_message)
        # Update the progress bar
        progress_bar.progress((i + 1) / total_steps)
        if i in ind_add_delays:
            time.sleep(0.5)

    progress_bar.empty()
    status_message.empty()
    return result
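
# Hedged usage sketch (not part of the original module): since
# extract_filter_img is a coroutine, a Streamlit script needs an event loop
# to drive it, e.g. via asyncio.run. The st.file_uploader wiring and the
# `uploaded_file` name below are illustrative assumptions.
#
#   import asyncio
#
#   uploaded_file = st.file_uploader("Upload a menu image", type=["png", "jpg", "jpeg"])
#   if uploaded_file is not None:
#       menu_strings = asyncio.run(extract_filter_img(uploaded_file.read()))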

def transcribe_menu_model(menu_text: List[str]) -> Dict:
    """Summarize the extracted menu strings with the causal LM and post-process the output."""
    prompt_item = model_inf_inp_prompt.format(menu_text)
    input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids.to(DEVICE)
    outputs = item_summarizer.generate(
        input_ids,
        max_new_tokens=512,
        num_beams=4,
        pad_token_id=item_tokenizer.pad_token_id,
        eos_token_id=item_tokenizer.eos_token_id,
        bos_token_id=item_tokenizer.bos_token_id,
    )
    # Keep special tokens: the post-processing patterns below handle them
    prediction = item_tokenizer.batch_decode(outputs, skip_special_tokens=False)
    postpro_output = post_process_gen_outputs(prediction, header_pattern, dots_pattern)[0]
    return postpro_output
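
# Hedged end-to-end sketch (assumed glue code, not from the original app):
# the extraction and summarization steps chain together like this, with
# `image_bytes` as an illustrative placeholder.
#
#   menu_strings = asyncio.run(extract_filter_img(image_bytes))
#   summary = transcribe_menu_model(menu_strings)
#   st.write(summary)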

def classify_menu_text(extrc_str: List[str]) -> List[str]:
    # Placeholder: currently passes the extracted strings through unchanged.
    return extrc_str
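
# Hedged sketch of what a real classify_menu_text could look like (an
# assumption; the pass-through stub above is the actual behavior). A
# lightweight heuristic might keep only strings containing a price-like token:
#
#   import re
#   PRICE_LIKE = re.compile(r'[\$€£]?\s*\d+(?:[.,]\d{1,2})?')
#   def classify_menu_text(extrc_str: List[str]) -> List[str]:
#       return [s for s in extrc_str if PRICE_LIKE.search(s)]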