import streamlit as st

from inference.preprocess_image import (
    image_to_np_arr,
    process_extracted_text,
    post_process_gen_outputs
)

from inference.config import (
    model_inf_inp_prompt,
    header_pattern,
    dots_pattern,
    DEVICE,
    model_name
)
from typing import List, Tuple, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
import easyocr
import time

# easyocr should only use the GPU when one is available
use_gpu = DEVICE.type != 'cpu'

@st.cache_resource
def load_models(item_summarizer: str) -> Tuple:
    """
    Load the models required for the inference process. Cached so the models
    are not reloaded on every call.

    Parameters:
        item_summarizer: str, required -> Name of the LLM used for item summarization.

    Returns:
        Tuple -> The models required for the inference process.
    """

    # model to extract text from the image
    text_extractor = easyocr.Reader(['en'], gpu=use_gpu)

    # tokenizer and model to generate the item summary
    tokenizer = AutoTokenizer.from_pretrained(item_summarizer)
    model = AutoModelForCausalLM.from_pretrained(item_summarizer)

    return (text_extractor, tokenizer, model)

text_extractor, item_tokenizer, item_summarizer = load_models(item_summarizer=model_name)


# Extract and filter text from the uploaded image
async def extract_filter_img(image) -> List[str]:
    """
    1. Convert the image to a numpy array
    2. Detect and extract text from the image (list of tuples)
    3. Process the text to filter out irrelevant strings
    4. Keep only menu-related strings from the detected text
    """
    
    progress_bar = st.progress(0)
    status_message = st.empty()

    functions_messages = [
        (image_to_np_arr, 'Converting Image to required format', 'Done Converting !'),
        (text_extractor.readtext, 'Extracting text from input image', 'Done Extracting !'),
        (process_extracted_text, 'Clean Raw Extracted text', 'Done Cleaning !'),
        (classify_menu_text, 'Removing non-menu related text', 'Done removing !'),
    ]
    
    # Initialize variables
    result = image
    total_steps = len(functions_messages)
    ind_add_delays = [0, 2, 3]  # step indices that get a short pause around the status update

    # Loop through each function and execute it with status update
    for i, (func, start_message, end_message) in enumerate(functions_messages):
        status_message.write(start_message)

        if i in ind_add_delays:
            time.sleep(0.5)

        # the step at index 2 (process_extracted_text) is a coroutine and must be awaited
        if i == 2:
            result = await func(result)
        else:
            result = func(result)
        
        status_message.write(end_message)

        # Update the progress bar
        progress_bar.progress((i + 1) / total_steps)

        if i in ind_add_delays:
            time.sleep(0.5)

    progress_bar.empty()
    status_message.empty()
    return result


def transcribe_menu_model(menu_text: List[str]) -> Dict:
    """
    Summarize the extracted menu text with the cached LLM and post-process
    the generated output.
    """

    prompt_item = model_inf_inp_prompt.format(menu_text)
    input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids

    outputs = item_summarizer.generate(input_ids,
                                       max_new_tokens=512,
                                       num_beams=4,
                                       pad_token_id=item_tokenizer.pad_token_id,
                                       eos_token_id=item_tokenizer.eos_token_id,
                                       bos_token_id=item_tokenizer.bos_token_id)

    # special tokens are kept here; post_process_gen_outputs cleans them up below
    prediction = item_tokenizer.batch_decode(outputs, skip_special_tokens=False)

    postpro_output = post_process_gen_outputs(prediction, header_pattern, dots_pattern)[0]

    return postpro_output

def classify_menu_text(extrc_str: List[str]) -> List[str]:
    # placeholder pass-through; a real menu/non-menu classifier can be plugged in here
    return extrc_str
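
# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not the app's confirmed entry point): one way
# the pieces above could be wired into a Streamlit page. The uploader label,
# the asyncio.run call, and the assumption that image_to_np_arr accepts the
# uploaded file object are illustrative.
# ---------------------------------------------------------------------------
def _demo_app():
    import asyncio

    uploaded = st.file_uploader("Upload a menu image", type=["png", "jpg", "jpeg"])
    if uploaded is not None:
        # extract_filter_img is a coroutine, so it needs an event loop when
        # called from a synchronous Streamlit script
        menu_strings = asyncio.run(extract_filter_img(uploaded))
        st.write(transcribe_menu_model(menu_strings))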