Amith Adiraju
committed on
Commit
·
5a29f4a
1
Parent(s):
11b899a
1. Added custom fine-tuned model to provide item explanations in a specific format.
Browse files
2. Provided capability to enter menu items manually rather than uploading an image.
3. Created multiple pages and redirected code accordingly.
4. Added robust regular expressions and other techniques to post process outputs to user.
Signed-off-by: Amith Adiraju <[email protected]>
- app.py +37 -177
- inference/config.py +16 -26
- inference/preprocess_image.py +57 -4
- inference/translate.py +41 -16
- pages.py +214 -0
- utils.py +15 -0
app.py
CHANGED
@@ -1,204 +1,64 @@
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from streamlit import session_state as sst
|
3 |
-
from typing import List, Optional
|
4 |
import asyncio
|
5 |
-
import pandas as pd
|
6 |
-
|
7 |
-
from inference.translate import (
|
8 |
-
extract_filter_img,
|
9 |
-
transcribe_menu_model
|
10 |
-
)
|
11 |
-
|
12 |
-
from inference.config import DEBUG_MODE
|
13 |
-
from PIL import Image
|
14 |
-
import time
|
15 |
-
|
16 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
17 |
-
import os
|
18 |
-
|
19 |
-
# Setting workers to be 70% of all available virtual cpus in system
|
20 |
-
cpu_count = os.cpu_count()
|
21 |
-
pool = ThreadPoolExecutor(max_workers=int(cpu_count*0.7) )
|
22 |
|
|
|
23 |
# Initialize session state variable to start with home page
|
24 |
if "page" not in sst:
|
25 |
sst["page"] = "Home"
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
Parameters:
|
33 |
-
page: str, required.
|
34 |
-
|
35 |
-
Returns:
|
36 |
-
None
|
37 |
-
"""
|
38 |
-
|
39 |
-
sst["page"] = page
|
40 |
-
|
41 |
-
async def main_page() -> None:
|
42 |
-
"""
|
43 |
-
Function that contains content of main page i.e., image uploader and submit button to navigate to next page.
|
44 |
-
Upon submit , control goes to model inference 'page'.
|
45 |
-
|
46 |
-
Parameters:
|
47 |
-
None
|
48 |
-
|
49 |
-
Returns:
|
50 |
-
None
|
51 |
-
"""
|
52 |
-
|
53 |
-
# Streamlit app
|
54 |
-
first_title = st.empty()
|
55 |
-
first_title.title("App that explains your menu items ")
|
56 |
-
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
type=["jpg", "jpeg", "png"])
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
if uploaded_file is not None:
|
67 |
-
image = Image.open(uploaded_file)
|
68 |
-
|
69 |
-
# Only show if user wants to see
|
70 |
-
if st.checkbox('Show Uploaded Image'):
|
71 |
-
st.image(image,
|
72 |
-
caption='Uploaded Image',
|
73 |
-
use_column_width=True)
|
74 |
-
|
75 |
-
sst["input_image"] = image
|
76 |
-
|
77 |
-
# Submit button
|
78 |
-
st.button("Submit",
|
79 |
-
on_click = navigate_to,
|
80 |
-
args = ("Inference",))
|
81 |
-
|
82 |
-
|
83 |
-
st.info("""This application is for education purposes only. It uses AI, hence it's dietary
|
84 |
-
recommendations are not to be taken as medical advice, author doesn't bear responsibility
|
85 |
-
for incorrect dietary recommendations. Please proceed with caution.
|
86 |
-
""")
|
87 |
|
|
|
88 |
|
89 |
-
|
|
|
|
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
waiting for all threads to be done.
|
95 |
|
96 |
-
Parameters:
|
97 |
-
inp_texts: List[str], required -> List of strings, containing item names of a menu in english.
|
98 |
-
|
99 |
-
Returns:
|
100 |
-
None
|
101 |
-
"""
|
102 |
-
|
103 |
-
df = pd.DataFrame([('ITEM NAME', 'EXPLANATION')]
|
104 |
-
)
|
105 |
-
|
106 |
-
sl_table = st.table(df)
|
107 |
-
tp_futures = { pool.submit(transcribe_menu_model, mi): mi for mi in inp_texts }
|
108 |
|
109 |
-
for tpftr in as_completed(tp_futures):
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
try:
|
114 |
-
exp = tpftr.result()
|
115 |
-
sl_table.add_rows([(item,exp)] )
|
116 |
-
|
117 |
-
except Exception as e:
|
118 |
-
print("Could not add a new row dynamically, because of this error:", e)
|
119 |
-
|
120 |
-
return
|
121 |
-
|
122 |
-
|
123 |
-
async def model_inference():
|
124 |
-
|
125 |
"""
|
126 |
-
|
127 |
-
and toggles state between pages if needed.
|
128 |
|
129 |
-
Parameters:
|
130 |
-
None
|
131 |
Returns:
|
132 |
None
|
133 |
-
|
134 |
"""
|
135 |
-
|
136 |
-
second_title = st.empty()
|
137 |
-
second_title.title(" Using ML to explain your menu items ... ")
|
138 |
-
|
139 |
-
if "input_image" in sst:
|
140 |
-
|
141 |
-
image = sst["input_image"]
|
142 |
-
|
143 |
-
msg1 = st.empty()
|
144 |
-
msg1.write("Pre-processing and extracting text out of your image ....")
|
145 |
-
st_filter = time.perf_counter()
|
146 |
-
|
147 |
-
# Call the extract_filter_img function
|
148 |
-
filtered_text = await extract_filter_img(image)
|
149 |
-
en_filter = time.perf_counter()
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
elif num_items_detected > 0:
|
157 |
-
st.write(f"Detected {num_items_detected} menu items from your input image ... ")
|
158 |
-
|
159 |
-
msg2 = st.empty()
|
160 |
-
msg2.write("All pre-processing done, transcribing your menu items now ....")
|
161 |
-
st_trans_llm = time.perf_counter()
|
162 |
-
|
163 |
-
await dist_llm_inference(filtered_text)
|
164 |
-
|
165 |
-
msg3 = st.empty()
|
166 |
-
msg3.write("Done transcribing ... ")
|
167 |
-
en_trans_llm = time.perf_counter()
|
168 |
-
|
169 |
-
msg1.empty(); msg2.empty(); msg3.empty()
|
170 |
-
st.success("Image processed successfully! " )
|
171 |
-
|
172 |
-
if DEBUG_MODE:
|
173 |
-
filter_time_sec = en_filter - st_filter
|
174 |
-
llm_time_sec = en_trans_llm - st_trans_llm
|
175 |
-
total_time_sec = filter_time_sec + llm_time_sec
|
176 |
-
|
177 |
-
st.write("Time took to extract and filter text {}".format(filter_time_sec))
|
178 |
-
st.write("Time took to summarize by LLM {}".format(llm_time_sec))
|
179 |
-
st.write('Overall time taken in seconds: {}'.format(total_time_sec))
|
180 |
-
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
async def main():
|
191 |
-
"""
|
192 |
-
Function that toggles between pages based on state variables.
|
193 |
|
194 |
-
Parameters:
|
195 |
-
None
|
196 |
-
Returns:
|
197 |
-
None
|
198 |
-
"""
|
199 |
-
if sst["page"] == "Home":
|
200 |
-
await main_page()
|
201 |
elif sst["page"] == "Inference":
|
202 |
-
await
|
203 |
|
204 |
asyncio.run(main())
|
|
|
1 |
+
from utils import navigate_to
|
2 |
+
from pages import manual_input_page, image_input_page, model_inference_page
|
3 |
+
|
4 |
import streamlit as st
|
5 |
from streamlit import session_state as sst
|
|
|
6 |
import asyncio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
# TODO: Fix model inference and post-processing functions before moving to production.
|
9 |
# Initialize session state variable to start with home page
|
10 |
if "page" not in sst:
|
11 |
sst["page"] = "Home"
|
12 |
|
13 |
+
# Clear every session-state entry except the current page marker.
def reset_sst():
    """
    Remove all keys from the Streamlit session state, keeping only "page".

    Parameters:
        None

    Returns:
        None
    """
    # Snapshot the keys to drop first, so the state is not mutated while it is being iterated.
    stale_keys = [name for name in sst.keys() if name != "page"]
    for name in stale_keys:
        sst.pop(name, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
# Landing page: lets the user choose how the menu items are supplied.
async def landing_page():
    """
    Render the home page: a title, some vertical spacing and two side-by-side
    buttons that switch to the manual-entry page or the image-upload page.

    Parameters:
        None

    Returns:
        None
    """
    st.title("We will explain your menu like never before!")

    # Three newline writes are used as vertical spacing below the title.
    for _ in range(3):
        st.write("\n")

    manual_col, image_col = st.columns(2)

    # Left column: clicking sets the page to "ManualInput".
    with manual_col:
        st.button("Enter Items Manually", on_click=navigate_to, args=("ManualInput",))

    # Right column: clicking sets the page to "ImageInput".
    with image_col:
        st.button("Upload Items from Image", on_click=navigate_to, args=("ImageInput",))
|
|
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
|
|
38 |
|
39 |
+
# Entry point that routes to the page named in the session state.
async def main():
    """
    Render the page selected by sst["page"].

    "Home", "ManualInput" and "ImageInput" first clear all session state
    except the page marker and then render their page. "Inference" renders
    the inference page without clearing anything, so the input collected on
    the previous page is still available. Any other page name renders nothing.

    Returns:
        None
    """
    current_page = sst["page"]

    # The inference page must see the previously stored input, so no reset here.
    if current_page == "Inference":
        await model_inference_page()
        return

    # Pages that start from a clean session state.
    fresh_pages = {
        "Home": landing_page,
        "ManualInput": manual_input_page,
        "ImageInput": image_input_page,
    }

    render_page = fresh_pages.get(current_page)
    if render_page is None:
        return

    reset_sst()
    await render_page()
|
63 |
|
64 |
asyncio.run(main())
|
inference/config.py
CHANGED
@@ -1,33 +1,23 @@
|
|
1 |
-
|
2 |
-
|
3 |
|
4 |
-
Item
|
5 |
-
|
6 |
-
|
7 |
-
It goes well with: White basmati rice or Indian flat bread.\n
|
8 |
-
Allergens: Paneer may cause digestive discomfort and intolerance to some.\n
|
9 |
-
Food Category: Vegetarian, Vegans may not like it, as paneer is usually made from cow milk.
|
10 |
|
|
|
|
|
11 |
|
12 |
-
Item -> rumali roti.\n
|
13 |
-
Explanation -> Major Ingredients here: roti.\n
|
14 |
-
How it is made: A small soft bread, made to size of a napkin ( a.k.a 'rumal' in hindi ); usually made with a combination of whole wheat and all purpose flour.\n
|
15 |
-
It goes well with: Most indian gravies such as palak paneer, tomato curry etc.\n
|
16 |
-
Allergens: May contain gluten, which is known to cause digestive discomfort and intolerance to some.\n
|
17 |
-
Food Category: Vegetarian, Vegan.
|
18 |
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
26 |
|
|
|
27 |
|
28 |
-
|
29 |
-
Item ->
|
30 |
-
"""
|
31 |
-
|
32 |
-
DEBUG_MODE = False
|
33 |
-
DEVICE = 'cpu'
|
|
|
1 |
+
import torch
|
2 |
+
import re
|
3 |
|
4 |
+
model_inf_inp_prompt = "INSTRUCTION: given food item name, explain these things:(major ingredients,making process,portion & spicy/sweet,pairs with,allergens,food type(veg/non-veg/vegan)). ensure to get allergens and food category factually correct.Item Name: {} "
|
5 |
+
header_pattern = r'Item Name: (.*?)\. Major Ingredients: (.*?)\. Making Process: (.*?)\. Portion and Spice Level: (.*?)\. Pairs With: (.*?)\. Allergens: (.*?)\. Food Type: (.*?)\.\s*</s>'
|
6 |
+
dots_pattern = re.compile(r'\.{3,}')
|
|
|
|
|
|
|
7 |
|
8 |
+
DEBUG_MODE = True
|
9 |
+
model_name = "AmithAdiraju1694/gpt-neo-125M_menuitemexp"
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
+
def get_device():
    """
    Select the torch device used for inference and print which one was chosen.

    Returns:
        torch.device -> "cuda" when torch reports CUDA as available, "cpu" otherwise.
    """
    if not torch.cuda.is_available():
        cpu_device = torch.device("cpu")
        print("Using CPU")
        return cpu_device

    gpu_device = torch.device("cuda")
    # Report the name of GPU 0, the device that is queried here.
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    return gpu_device
|
22 |
|
23 |
+
DEVICE = get_device()
|
|
|
|
|
|
|
|
|
|
inference/preprocess_image.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
|
2 |
import numpy as np
|
3 |
-
from typing import List, Tuple, Optional, AnyStr
|
4 |
import nltk
|
5 |
nltk.download("stopwords")
|
6 |
nltk.download('punkt')
|
@@ -53,11 +53,64 @@ def image_to_np_arr(image) -> np.array:
|
|
53 |
return np.array(image)
|
54 |
|
55 |
async def process_extracted_text(raw_extrc_text: List[Tuple]) -> List[AnyStr]:
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
output_texts = []
|
58 |
for _, extr_text, _ in raw_extrc_text:
|
59 |
# remove all numbers, special characters from a string
|
60 |
prcsd_txt = preprocess_text(extr_text)
|
61 |
-
if len(prcsd_txt.split(" ")
|
|
|
62 |
|
63 |
-
return output_texts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
import numpy as np
|
3 |
+
from typing import List, Tuple, Optional, AnyStr, Dict
|
4 |
import nltk
|
5 |
nltk.download("stopwords")
|
6 |
nltk.download('punkt')
|
|
|
53 |
return np.array(image)
|
54 |
|
55 |
async def process_extracted_text(raw_extrc_text: List[Tuple]) -> List[AnyStr]:
    """
    Clean each extracted text with preprocess_text and keep only the texts
    that still contain at least two words.

    Parameters:
        raw_extrc_text: List[Tuple], required -> A list of 3-tuples whose middle element
            is the extracted text; the first and last elements are ignored here.

    Returns:
        List[AnyStr] -> The cleaned texts with two or more words, in input order.
    """
    output_texts = []
    for _, extr_text, _ in raw_extrc_text:
        # remove all numbers, special characters from a string
        prcsd_txt = preprocess_text(extr_text)

        # split() without an argument ignores leading, trailing and repeated whitespace,
        # so a single word surrounded by spaces is not miscounted as several words
        # (split(" ") would yield empty strings for every extra space).
        if len(prcsd_txt.split()) >= 2:
            output_texts.append(prcsd_txt)

    return output_texts
|
74 |
+
|
75 |
+
def post_process_gen_outputs(gen_output: List[str], header_pattern: str, dots_pattern: "Union[str, re.Pattern]") -> List[Dict]:
    """
    Convert raw generated texts into dictionaries keyed by explanation section.

    For every text:
      * if header_pattern matches, the captured groups of the first match are paired
        with the section headers; the ingredients list is tidied, and any value that
        is shorter than 3 characters or contains a header name is replaced with
        "Sorry, can't explain this.";
      * otherwise, if the text contains "Major Ingredients", the text between its first
        and second occurrence (or to the end) is returned under the key
        "May contain misleading explanation", with '</s>' and runs of dots removed;
      * otherwise {"Sorry, can't explain this item": "NA"} is returned.
    The "Item Name" entry is always dropped from the result.

    Parameters:
        gen_output: List[str], required -> Generated texts. The list is not modified.
        header_pattern: str, required -> Regular expression with one group per section header.
        dots_pattern: compiled pattern or str, required -> Pattern whose matches are removed
            from the fallback text. A plain string is compiled before use.

    Returns:
        List[Dict] -> One dictionary per input text, in input order.
    """
    # Section names, in the order of the capture groups of header_pattern
    headers = ["Item Name", "Major Ingredients", "Making Process", "Portion and Spice Level", "Pairs With", "Allergens", "Food Type"]
    fallback_header = headers[1]

    # The signature historically advertised a str here although .sub() was called on it;
    # accept both forms so that a raw pattern string no longer raises AttributeError.
    if isinstance(dots_pattern, str):
        dots_pattern = re.compile(dots_pattern)

    # Normalise a comma separated list: trim each part and drop the empty ones
    def clean_string(input_string):
        trimmed = (part.strip() for part in input_string.split(','))
        return ', '.join(part for part in trimmed if part)

    processed = []
    for raw_text in gen_output:
        matches = re.findall(header_pattern, raw_text)

        if matches:
            # re.findall returns one tuple of groups per match; only the first match is used
            result = dict(zip(headers, matches[0]))
            result['Major Ingredients'] = clean_string(result['Major Ingredients'])

            # A value that is (nearly) empty or that swallowed another header is not a usable explanation.
            # Only existing keys are reassigned, so updating during iteration is safe.
            for key, value in result.items():
                if len(value) < 3 or any(header in value for header in headers):
                    result[key] = "Sorry, can't explain this."

        elif fallback_header in raw_text:
            leftover = raw_text.split(fallback_header)[1].strip().replace('</s>', '')
            result = {"May contain misleading explanation": dots_pattern.sub('', leftover)}

        else:
            result = {"Sorry, can't explain this item": "NA"}

        result.pop('Item Name', None)
        processed.append(result)

    return processed
|
116 |
+
|
inference/translate.py
CHANGED
@@ -2,29 +2,50 @@ import streamlit as st
|
|
2 |
|
3 |
from inference.preprocess_image import (
|
4 |
image_to_np_arr,
|
5 |
-
process_extracted_text
|
|
|
6 |
)
|
7 |
|
8 |
-
from inference.config import
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from typing import List, Tuple, Optional, AnyStr, Dict
|
10 |
-
from transformers import
|
11 |
import easyocr
|
12 |
import time
|
13 |
|
14 |
use_gpu = True
|
15 |
-
if DEVICE == 'cpu': use_gpu = False
|
16 |
|
17 |
@st.cache_resource
|
18 |
def load_models(item_summarizer: AnyStr) -> Tuple:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
text_extractor = easyocr.Reader(['en'],
|
20 |
gpu = use_gpu
|
21 |
)
|
22 |
-
|
23 |
-
model
|
|
|
|
|
24 |
|
25 |
return (text_extractor, tokenizer, model)
|
26 |
|
27 |
-
text_extractor,item_tokenizer,item_summarizer = load_models(item_summarizer =
|
28 |
|
29 |
|
30 |
# Define your extract_filter_img function
|
@@ -78,20 +99,24 @@ async def extract_filter_img(image) -> Dict:
|
|
78 |
|
79 |
def transcribe_menu_model(menu_text: List[AnyStr]) -> Dict:
|
80 |
|
81 |
-
prompt_item =
|
82 |
-
|
83 |
-
|
84 |
-
"""
|
85 |
input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids
|
86 |
|
87 |
outputs = item_summarizer.generate(input_ids,
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
)
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
)
|
95 |
|
96 |
def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
|
97 |
return extrc_str
|
|
|
2 |
|
3 |
from inference.preprocess_image import (
|
4 |
image_to_np_arr,
|
5 |
+
process_extracted_text,
|
6 |
+
post_process_gen_outputs
|
7 |
)
|
8 |
|
9 |
+
from inference.config import (
|
10 |
+
model_inf_inp_prompt,
|
11 |
+
header_pattern,
|
12 |
+
dots_pattern,
|
13 |
+
DEVICE,
|
14 |
+
model_name
|
15 |
+
)
|
16 |
from typing import List, Tuple, Optional, AnyStr, Dict
|
17 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
18 |
import easyocr
|
19 |
import time
|
20 |
|
21 |
# Let EasyOCR use the GPU unless the configured torch device is the CPU.
use_gpu = DEVICE.type != 'cpu'
|
23 |
|
24 |
@st.cache_resource
def load_models(item_summarizer: AnyStr) -> Tuple:
    """
    Load the OCR reader plus the tokenizer and causal language model used to explain
    menu items. Decorated with st.cache_resource so repeated calls reuse the loaded objects.

    Parameters:
        item_summarizer: str, required -> Name of the pretrained model to load the
            tokenizer and the language model from.

    Returns:
        Tuple -> (text extractor, tokenizer, model), in that order.
    """
    # English-only OCR reader; GPU usage follows the module-level use_gpu flag
    ocr_reader = easyocr.Reader(['en'], gpu=use_gpu)

    # Tokenizer and language model both come from the same pretrained checkpoint
    summary_tokenizer = AutoTokenizer.from_pretrained(item_summarizer)
    summary_model = AutoModelForCausalLM.from_pretrained(item_summarizer)

    return ocr_reader, summary_tokenizer, summary_model
|
47 |
|
48 |
+
text_extractor,item_tokenizer,item_summarizer = load_models(item_summarizer = model_name)
|
49 |
|
50 |
|
51 |
# Define your extract_filter_img function
|
|
|
99 |
|
100 |
def transcribe_menu_model(menu_text: AnyStr) -> Dict:
    """
    Generate an explanation for one menu item and return it as a post-processed dictionary.

    Parameters:
        menu_text: str, required -> Name of a single menu item; it is inserted into the
            inference prompt template.

    Returns:
        Dict -> The first (only) entry produced by post_process_gen_outputs for the
            decoded generation.
    """

    prompt_item = model_inf_inp_prompt.format(menu_text)

    # Tokenize once and keep the attention mask, so generate() does not have to infer it.
    encoded = item_tokenizer(prompt_item, return_tensors="pt")

    # Fall back to the EOS token when the tokenizer defines no padding token.
    pad_token_id = item_tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = item_tokenizer.eos_token_id

    outputs = item_summarizer.generate(encoded["input_ids"],
                                       attention_mask = encoded["attention_mask"],
                                       max_new_tokens = 512,
                                       num_beams = 4,
                                       pad_token_id = pad_token_id,
                                       eos_token_id = item_tokenizer.eos_token_id,
                                       bos_token_id = item_tokenizer.bos_token_id
                                       )

    # Special tokens are kept on purpose: header_pattern anchors on the trailing '</s>'.
    prediction = item_tokenizer.batch_decode(outputs,
                                             skip_special_tokens=False
                                             )

    postpro_output = post_process_gen_outputs( prediction, header_pattern, dots_pattern )[0]

    return postpro_output
|
|
|
120 |
|
121 |
def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
    """
    Pass-through placeholder: no classification is performed, the input list is
    returned unchanged (the very same object).

    Parameters:
        extrc_str: List[AnyStr], required -> Extracted texts.

    Returns:
        List[AnyStr] -> The input list itself.
    """
    return extrc_str
|
pages.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit import session_state as sst
|
3 |
+
|
4 |
+
|
5 |
+
from utils import navigate_to
|
6 |
+
from inference.config import DEBUG_MODE
|
7 |
+
|
8 |
+
from inference.translate import extract_filter_img, transcribe_menu_model,classify_menu_text
|
9 |
+
from inference.preprocess_image import preprocess_text
|
10 |
+
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
import pandas as pd
|
14 |
+
from PIL import Image
|
15 |
+
from typing import List
|
16 |
+
import json
|
17 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
18 |
+
|
19 |
+
# Setting workers to be 70% of all available virtual cpus in system, but never fewer than one:
# os.cpu_count() may return None, and int(1 * 0.7) == 0 would make ThreadPoolExecutor raise ValueError.
cpu_count = os.cpu_count() or 1
pool = ThreadPoolExecutor(max_workers=max(1, int(cpu_count * 0.7)))
|
22 |
+
|
23 |
+
# Function that handles logic of explaining menu items from manual input
async def manual_input_page():

    """
    Render the manual-entry page: the user picks how many text boxes to show and types one
    food item per box. Non-blank entries are stored in sst["user_entered_items"], and only
    then is the 'Explain My Menu' button (which switches to the "Inference" page) shown.

    Parameters:
        None

    Returns:
        None
    """

    st.write("This is the Manual Input Page.")
    st.write("Once done, click on 'Explain My Menu' button to get explanations for each item ... ")

    num_text_boxes = st.number_input("Number of text boxes", min_value=1, step=1)

    inp_texts = []
    for i in range(int(num_text_boxes)):
        text_box = st.text_input(f"Food item {i+1}")

        # Ignore boxes that are empty or contain only whitespace; they carry no item name.
        cleaned_entry = text_box.strip() if text_box else ""
        if cleaned_entry:
            inp_texts.append(cleaned_entry)

    if inp_texts:

        # Show user submit button only if they have entered some text and set text in session state
        sst["user_entered_items"] = inp_texts
        st.button("Explain My Menu",on_click=navigate_to,args=("Inference",))

    else:
        st.write("Please enter some items to proceed ...")


    st.button("Go back Home", on_click=navigate_to, args=("Home",))
|
57 |
+
|
58 |
+
|
59 |
+
# Function that handles logic of explaining menu items from image uploads
async def image_input_page():
    """
    Render the image-upload page: a file uploader, an optional preview of the uploaded
    image and a button that switches to the "Inference" page. The opened image is stored
    in sst["input_image"]; a file that cannot be opened as an image is reported to the
    user instead of crashing the page.

    Parameters:
        None

    Returns:
        None
    """

    st.write("This is the Image Input Page.")

    # Streamlit function to upload an image from any device
    uploaded_file = st.file_uploader("Choose an image...",
                                     type=["jpg", "jpeg", "png"])

    # Remove previous state's value of input image if it exists
    sst.pop('input_image', None)

    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file)
        except OSError:
            # Raised when the upload cannot be opened or identified as an image.
            image = None
            st.error("Could not read the uploaded file as an image, please try a different file.")

        if image is not None:
            # Only show if user wants to see
            if st.checkbox('Show Uploaded Image'):
                st.image(image,
                         caption='Uploaded Image',
                         use_column_width=True)

            sst["input_image"] = image

            # Show user submit button only if they have uploaded an image
            st.button("Translate My Menu",
                      on_click = navigate_to,
                      args = ("Inference",))


    # Warning message to user
    st.info("""This application is for education purposes only. It uses AI, hence it's dietary
    recommendations are not to be taken as medical advice, author doesn't bear responsibility
    for incorrect dietary recommendations. Please proceed with caution.
    """)

    # Let the user return to the home page (session state is cleared when Home is rendered by the caller)
    st.button("Go back Home", on_click=navigate_to, args=("Home",))
|
107 |
+
|
108 |
+
|
109 |
+
# Function that handles model inference
async def model_inference_page():

    """
    Collect menu items from the session state (an uploaded image and/or manually entered
    text), run concurrent LLM inference on them and display the results.

    When both sources are present, the manually entered items take precedence, because
    they are processed last. When neither is present, the page reports that no items
    were detected instead of failing.

    Parameters:
        None
    Returns:
        None

    """

    second_title = st.empty()
    second_title.title(" Using ML to explain your menu items ... ")

    # Start from "nothing detected" so the checks below are well defined even when the
    # session state holds neither an image nor manually entered items.
    filtered_text = []

    # User can either upload an image or enter text manually, we check for both
    if "input_image" in sst:
        image = sst["input_image"]

        msg1 = st.empty()
        msg1.write("Pre-processing and extracting text out of your image ....")
        # Call the extract_filter_img function
        filtered_text = await extract_filter_img(image)

    if "user_entered_items" in sst:
        user_text = sst["user_entered_items"]
        st.write("Pre-processing and filtering text from user input ....")

        # Drop entries that become empty after cleaning; there is nothing to explain for them.
        cleaned_entries = (preprocess_text(ut) for ut in user_text)
        filtered_text = [entry for entry in cleaned_entries if entry]

    num_items_detected = len(filtered_text)

    # irrespective of source of user entry , we check if we have any items to process
    if num_items_detected == 0:
        st.write("We couldn't detect any menu items ( indian for now ) from your image, please try a different image by going back.")

    else:
        st.write(f"Detected {num_items_detected} menu items from your input image ... ")

        msg2 = st.empty()
        msg2.write("All pre-processing done, transcribing your menu items now ....")
        st_trans_llm = time.perf_counter()

        await dist_llm_inference(filtered_text)

        msg3 = st.empty()
        msg3.write("Done transcribing ... ")
        en_trans_llm = time.perf_counter()

        # Remove the progress messages once the final status is shown
        msg2.empty(); msg3.empty()
        st.success("Image processed successfully! " )

        # Some basic stats for debug mode
        if DEBUG_MODE:
            llm_time_sec = en_trans_llm - st_trans_llm
            st.write("Time took to summarize by LLM {}".format(llm_time_sec))


    # Let the user return to the home page
    st.button("Go back Home", on_click=navigate_to, args=("Home",))
|
174 |
+
|
175 |
+
|
176 |
+
# Function that fans LLM inference out over the shared thread pool, one task per menu item
async def dist_llm_inference(inp_texts: List[str]) -> None:

    """
    Submit one transcribe_menu_model task per item to the thread pool and append each
    result to a streamlit table as soon as its task completes, instead of waiting for
    all tasks. A failure for an item is printed and that item is left out of the table.

    Parameters:
        inp_texts: List[str], required -> List of strings, containing item names of a menu in english.

    Returns:
        None
    """

    # The table starts with a single header-like row; result rows are appended below it
    header_frame = pd.DataFrame([('ITEM NAME', 'EXPLANATION')])
    results_table = st.table(header_frame)

    # Remember which menu item each future belongs to
    pending = {}
    for menu_item in inp_texts:
        pending[pool.submit(transcribe_menu_model, menu_item)] = menu_item

    # Futures are handled in completion order, not submission order
    for finished in as_completed(pending):
        menu_item = pending[finished]

        try:
            explanation = finished.result()
            results_table.add_rows([(menu_item, str(explanation))])

        except Exception as e:
            print("Could not add a new row dynamically, because of this error:", e)
|
214 |
+
|
utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from streamlit import session_state as sst
|
3 |
+
def navigate_to(page: str) -> None:
    """
    Record the page to display by writing it to the "page" entry of the streamlit
    session state. Used as a button callback to simulate navigation between pages.

    Parameters:
        page: str, required -> Name of the page to switch to.

    Returns:
        None
    """
    sst["page"] = page
|