Amith Adiraju commited on
Commit
5a29f4a
·
1 Parent(s): 11b899a

1. Added custom fine tuned model to provide item explanations is specific format.

Browse files

2. Provided capability to enter menu items manually than uploading an image.
3. Created multiple pages and redirected code accordingly.
4. Added robust regular expressions and other techniques to post process outputs to user.

Signed-off-by: Amith Adiraju <[email protected]>

Files changed (6) hide show
  1. app.py +37 -177
  2. inference/config.py +16 -26
  3. inference/preprocess_image.py +57 -4
  4. inference/translate.py +41 -16
  5. pages.py +214 -0
  6. utils.py +15 -0
app.py CHANGED
@@ -1,204 +1,64 @@
 
 
 
1
  import streamlit as st
2
  from streamlit import session_state as sst
3
- from typing import List, Optional
4
  import asyncio
5
- import pandas as pd
6
-
7
- from inference.translate import (
8
- extract_filter_img,
9
- transcribe_menu_model
10
- )
11
-
12
- from inference.config import DEBUG_MODE
13
- from PIL import Image
14
- import time
15
-
16
- from concurrent.futures import ThreadPoolExecutor, as_completed
17
- import os
18
-
19
- # Setting workers to be 70% of all available virtual cpus in system
20
- cpu_count = os.cpu_count()
21
- pool = ThreadPoolExecutor(max_workers=int(cpu_count*0.7) )
22
 
 
23
  # Initialize session state variable to start with home page
24
  if "page" not in sst:
25
  sst["page"] = "Home"
26
 
27
- def navigate_to(page: str) -> None:
28
- """
29
- Function to set the current page in the state of streamlit. A helper for
30
- simulating navigation in streamlit.
31
-
32
- Parameters:
33
- page: str, required.
34
-
35
- Returns:
36
- None
37
- """
38
-
39
- sst["page"] = page
40
-
41
- async def main_page() -> None:
42
- """
43
- Function that contains content of main page i.e., image uploader and submit button to navigate to next page.
44
- Upon submit , control goes to model inference 'page'.
45
-
46
- Parameters:
47
- None
48
-
49
- Returns:
50
- None
51
- """
52
-
53
- # Streamlit app
54
- first_title = st.empty()
55
- first_title.title("App that explains your menu items ")
56
-
57
 
58
- # Streamlit function to upload an image from any device
59
- uploaded_file = st.file_uploader("Choose an image...",
60
- type=["jpg", "jpeg", "png"])
61
 
62
- # Remove preivous states' value of input image if it exists
63
- sst.pop('input_image', None)
64
-
65
- # Submit button
66
- if uploaded_file is not None:
67
- image = Image.open(uploaded_file)
68
-
69
- # Only show if user wants to see
70
- if st.checkbox('Show Uploaded Image'):
71
- st.image(image,
72
- caption='Uploaded Image',
73
- use_column_width=True)
74
-
75
- sst["input_image"] = image
76
-
77
- # Submit button
78
- st.button("Submit",
79
- on_click = navigate_to,
80
- args = ("Inference",))
81
-
82
-
83
- st.info("""This application is for education purposes only. It uses AI, hence it's dietary
84
- recommendations are not to be taken as medical advice, author doesn't bear responsibility
85
- for incorrect dietary recommendations. Please proceed with caution.
86
- """)
87
 
 
88
 
89
- async def dist_llm_inference(inp_texts: List[str]) -> None:
 
 
90
 
91
- """
92
- Function that performs concurrent LLM inference using threadpool. It displays
93
- results of those threads that are done with execution, as a dynamic row to streamlit table, rather than
94
- waiting for all threads to be done.
95
 
96
- Parameters:
97
- inp_texts: List[str], required -> List of strings, containing item names of a menu in english.
98
-
99
- Returns:
100
- None
101
- """
102
-
103
- df = pd.DataFrame([('ITEM NAME', 'EXPLANATION')]
104
- )
105
-
106
- sl_table = st.table(df)
107
- tp_futures = { pool.submit(transcribe_menu_model, mi): mi for mi in inp_texts }
108
 
109
- for tpftr in as_completed(tp_futures):
110
 
111
- item = tp_futures[tpftr]
112
-
113
- try:
114
- exp = tpftr.result()
115
- sl_table.add_rows([(item,exp)] )
116
-
117
- except Exception as e:
118
- print("Could not add a new row dynamically, because of this error:", e)
119
-
120
- return
121
-
122
-
123
- async def model_inference():
124
-
125
  """
126
- Function that pre-processes input text from state variables, does concurrent inference
127
- and toggles state between pages if needed.
128
 
129
- Parameters:
130
- None
131
  Returns:
132
  None
133
-
134
  """
135
-
136
- second_title = st.empty()
137
- second_title.title(" Using ML to explain your menu items ... ")
138
-
139
- if "input_image" in sst:
140
-
141
- image = sst["input_image"]
142
-
143
- msg1 = st.empty()
144
- msg1.write("Pre-processing and extracting text out of your image ....")
145
- st_filter = time.perf_counter()
146
-
147
- # Call the extract_filter_img function
148
- filtered_text = await extract_filter_img(image)
149
- en_filter = time.perf_counter()
150
 
151
- num_items_detected = len(filtered_text)
152
-
153
- if num_items_detected == 0:
154
- st.write("We couldn't detect any menu items ( indian for now ) from your image, please try a different image.")
155
-
156
- elif num_items_detected > 0:
157
- st.write(f"Detected {num_items_detected} menu items from your input image ... ")
158
-
159
- msg2 = st.empty()
160
- msg2.write("All pre-processing done, transcribing your menu items now ....")
161
- st_trans_llm = time.perf_counter()
162
-
163
- await dist_llm_inference(filtered_text)
164
-
165
- msg3 = st.empty()
166
- msg3.write("Done transcribing ... ")
167
- en_trans_llm = time.perf_counter()
168
-
169
- msg1.empty(); msg2.empty(); msg3.empty()
170
- st.success("Image processed successfully! " )
171
-
172
- if DEBUG_MODE:
173
- filter_time_sec = en_filter - st_filter
174
- llm_time_sec = en_trans_llm - st_trans_llm
175
- total_time_sec = filter_time_sec + llm_time_sec
176
-
177
- st.write("Time took to extract and filter text {}".format(filter_time_sec))
178
- st.write("Time took to summarize by LLM {}".format(llm_time_sec))
179
- st.write('Overall time taken in seconds: {}'.format(total_time_sec))
180
-
181
 
182
- st.button("translate another",
183
- on_click=navigate_to,
184
- args=("Home",))
185
-
186
- else:
187
- st.write("Looks like image upload failed, please try uploading it again ... ")
188
-
189
-
190
- async def main():
191
- """
192
- Function that toggles between pages based on state variables.
193
 
194
- Parameters:
195
- None
196
- Returns:
197
- None
198
- """
199
- if sst["page"] == "Home":
200
- await main_page()
201
  elif sst["page"] == "Inference":
202
- await model_inference()
203
 
204
  asyncio.run(main())
 
1
+ from utils import navigate_to
2
+ from pages import manual_input_page, image_input_page, model_inference_page
3
+
4
  import streamlit as st
5
  from streamlit import session_state as sst
 
6
  import asyncio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ #TODO: Fix model inference and post processing function befor emoving ot production.
9
  # Initialize session state variable to start with home page
10
  if "page" not in sst:
11
  sst["page"] = "Home"
12
 
13
+ # function to remove all sesion variables from sst, except page.
14
+ def reset_sst():
15
+ for key in list(sst.keys()):
16
+ if key != "page":
17
+ sst.pop(key, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # Landing page function
20
+ async def landing_page():
 
21
 
22
+ st.title("We will explain your menu like never before!")
23
+ st.write("\n")
24
+ st.write("\n")
25
+ st.write("\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ c1, c2= st.columns(2)
28
 
29
+ with c1:
30
+ # Navigate to manual input page if user clicks on the button
31
+ st.button("Enter Items Manually", on_click=navigate_to, args=("ManualInput",))
32
 
33
+ with c2:
34
+ # Navigate to image input page if user clicks on the button
35
+ st.button("Upload Items from Image", on_click=navigate_to, args=("ImageInput",))
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
 
38
 
39
+ # Main function to handle navigation
40
+ async def main():
 
 
 
 
 
 
 
 
 
 
 
 
41
  """
42
+ Main function that handles the navigation logic based on the current page.
 
43
 
 
 
44
  Returns:
45
  None
 
46
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ # Navigation logic
49
+ if sst["page"] == "Home":
50
+ reset_sst() # reset all session state variables before navigating to the landing page
51
+ await landing_page() # Call the landing page function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ elif sst["page"] == "ManualInput":
54
+ reset_sst() # reset all session state variables before navigating to the landing page
55
+ await manual_input_page() # Call the manual input page function
56
+
57
+ elif sst["page"] == "ImageInput":
58
+ reset_sst() # reset all session state variables before navigating to the landing page
59
+ await image_input_page() # Call the image input page function
 
 
 
 
60
 
 
 
 
 
 
 
 
61
  elif sst["page"] == "Inference":
62
+ await model_inference_page() # Call the model inference page function
63
 
64
  asyncio.run(main())
inference/config.py CHANGED
@@ -1,33 +1,23 @@
1
- INSTRUCTION_PROMPT = """
2
- The following text contains examples of three items and their corresponding explanations in the required format.\n
3
 
4
- Item -> palak paneer.\n
5
- Explanation -> Major Ingredients here: paneer ( a.k.a cottage cheese ) , palak ( spinach ).\n
6
- How it is made: It's a savory item, made like a gravy; usually made by sauteing spices and mixing saute with boiled paneer and palak.\n
7
- It goes well with: White basmati rice or Indian flat bread.\n
8
- Allergens: Paneer may cause digestive discomfort and intolerance to some.\n
9
- Food Category: Vegetarian, Vegans may not like it, as paneer is usually made from cow milk.
10
 
 
 
11
 
12
- Item -> rumali roti.\n
13
- Explanation -> Major Ingredients here: roti.\n
14
- How it is made: A small soft bread, made to size of a napkin ( a.k.a 'rumal' in hindi ); usually made with a combination of whole wheat and all purpose flour.\n
15
- It goes well with: Most indian gravies such as palak paneer, tomato curry etc.\n
16
- Allergens: May contain gluten, which is known to cause digestive discomfort and intolerance to some.\n
17
- Food Category: Vegetarian, Vegan.
18
 
19
 
20
- Item -> nizami handi.\n
21
- Explanation -> Major Ingredients here: Different veggies, makhani sauce (skimmed milk, tomato and cashew paste , indian spices), combination of nuts.\n
22
- How it is made: Makhani sauce is added to onion-tomato based paste and bought to a boil; a Medley of veggies and gently flavored whole spices are added and boiled for small time.\n
23
- It goes well with: Different kinds of indian flat breads, white basmati and sonamasoori rice.\n
24
- Allergens: Presence of nuts, butter cream and makhani sauce are known to cause digestive discomfort and intolerance to some.\n
25
- Food Category: Usually vegetarian, may include chicken or animal meat sometimes, please check with hotel.
 
26
 
 
27
 
28
- Based on Item and explanation pairs provided above, provide similar explanation ('Major Ingredients', 'How is it made', 'It goes well with', 'Allergens' and 'Food Category') to the below item.\n
29
- Item ->
30
- """
31
-
32
- DEBUG_MODE = False
33
- DEVICE = 'cpu'
 
1
+ import torch
2
+ import re
3
 
4
+ model_inf_inp_prompt = "INSTRUCTION: given food item name, explain these things:(major ingredients,making process,portion & spicy/sweet,pairs with,allergens,food type(veg/non-veg/vegan)). ensure to get allergens and food category factually correct.Item Name: {} "
5
+ header_pattern = r'Item Name: (.*?)\. Major Ingredients: (.*?)\. Making Process: (.*?)\. Portion and Spice Level: (.*?)\. Pairs With: (.*?)\. Allergens: (.*?)\. Food Type: (.*?)\.\s*</s>'
6
+ dots_pattern = re.compile(r'\.{3,}')
 
 
 
7
 
8
+ DEBUG_MODE = True
9
+ model_name = "AmithAdiraju1694/gpt-neo-125M_menuitemexp"
10
 
 
 
 
 
 
 
11
 
12
 
13
+ def get_device():
14
+ if torch.cuda.is_available():
15
+ device = torch.device("cuda")
16
+ print(f"Using GPU: {torch.cuda.get_device_name(0)}") #get the name of the GPU being used.
17
+ else:
18
+ device = torch.device("cpu")
19
+ print("Using CPU")
20
 
21
+ return device
22
 
23
+ DEVICE = get_device()
 
 
 
 
 
inference/preprocess_image.py CHANGED
@@ -1,6 +1,6 @@
1
 
2
  import numpy as np
3
- from typing import List, Tuple, Optional, AnyStr
4
  import nltk
5
  nltk.download("stopwords")
6
  nltk.download('punkt')
@@ -53,11 +53,64 @@ def image_to_np_arr(image) -> np.array:
53
  return np.array(image)
54
 
55
  async def process_extracted_text(raw_extrc_text: List[Tuple]) -> List[AnyStr]:
56
-
 
 
 
 
 
 
 
 
 
57
  output_texts = []
58
  for _, extr_text, _ in raw_extrc_text:
59
  # remove all numbers, special characters from a string
60
  prcsd_txt = preprocess_text(extr_text)
61
- if len(prcsd_txt.split(" ") ) >= 2: output_texts.append(prcsd_txt)
 
62
 
63
- return output_texts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  import numpy as np
3
+ from typing import List, Tuple, Optional, AnyStr, Dict
4
  import nltk
5
  nltk.download("stopwords")
6
  nltk.download('punkt')
 
53
  return np.array(image)
54
 
55
  async def process_extracted_text(raw_extrc_text: List[Tuple]) -> List[AnyStr]:
56
+ """
57
+ Function that processes extracted text by removing numbers and special characters,
58
+ and filters out text with less than 2 words.
59
+
60
+ Parameters:
61
+ raw_extrc_text: List[Tuple], required -> A list of tuples containing extracted text.
62
+
63
+ Returns:
64
+ List[AnyStr] -> A list of processed text strings.
65
+ """
66
  output_texts = []
67
  for _, extr_text, _ in raw_extrc_text:
68
  # remove all numbers, special characters from a string
69
  prcsd_txt = preprocess_text(extr_text)
70
+ if len(prcsd_txt.split(" ")) >= 2:
71
+ output_texts.append(prcsd_txt)
72
 
73
+ return output_texts
74
+
75
+ def post_process_gen_outputs(gen_output: List[str], header_pattern: str, dots_pattern:str) -> List[Dict]:
76
+
77
+ # Define the regular expression pattern to match section names and placeholders
78
+ headers = ["Item Name", "Major Ingredients", "Making Process", "Portion and Spice Level", "Pairs With", "Allergens", "Food Type"]
79
+
80
+ # Function to clean the strings
81
+ def clean_string(input_string):
82
+ parts = input_string.split(',')
83
+ cleaned_parts = [part.strip() for part in parts if part.strip()]
84
+ return ', '.join(cleaned_parts)
85
+
86
+ for i in range(len(gen_output)):
87
+ # Find all matches
88
+ matches = re.findall(header_pattern, gen_output[i])
89
+
90
+ # Since re.findall returns a list of tuples, we need to extract the first tuple
91
+ if matches:
92
+ result = dict(zip(headers,matches[0]))
93
+ result['Major Ingredients'] = clean_string(result['Major Ingredients'])
94
+
95
+ # if any of dictionary values strings are emtpy, replace it with string "Sorry, can't explain this."
96
+ for k in result.keys():
97
+ if len(result[k]) < 3 or any(header in result[k] for header in headers):
98
+ result[k] = "Sorry, can't explain this."
99
+
100
+ gen_output[i] = result
101
+
102
+ else:
103
+ if headers[1] in gen_output[i]:
104
+
105
+ gen_output[i] = {"May contain misleading explanation":
106
+ dots_pattern.sub('' ,
107
+ gen_output[i].split(headers[1]
108
+ )[1].strip().replace('</s>', '')
109
+ )
110
+ }
111
+ else:
112
+ gen_output[i] = {"Sorry, can't explain this item": "NA"}
113
+
114
+ gen_output[i].pop('Item Name', None)
115
+ return gen_output
116
+
inference/translate.py CHANGED
@@ -2,29 +2,50 @@ import streamlit as st
2
 
3
  from inference.preprocess_image import (
4
  image_to_np_arr,
5
- process_extracted_text
 
6
  )
7
 
8
- from inference.config import INSTRUCTION_PROMPT, DEVICE
 
 
 
 
 
 
9
  from typing import List, Tuple, Optional, AnyStr, Dict
10
- from transformers import T5Tokenizer, T5ForConditionalGeneration
11
  import easyocr
12
  import time
13
 
14
  use_gpu = True
15
- if DEVICE == 'cpu': use_gpu = False
16
 
17
  @st.cache_resource
18
  def load_models(item_summarizer: AnyStr) -> Tuple:
 
 
 
 
 
 
 
 
 
 
 
 
19
  text_extractor = easyocr.Reader(['en'],
20
  gpu = use_gpu
21
  )
22
- tokenizer = T5Tokenizer.from_pretrained(item_summarizer)
23
- model = T5ForConditionalGeneration.from_pretrained(item_summarizer)
 
 
24
 
25
  return (text_extractor, tokenizer, model)
26
 
27
- text_extractor,item_tokenizer,item_summarizer = load_models(item_summarizer = "google/flan-t5-large")
28
 
29
 
30
  # Define your extract_filter_img function
@@ -78,20 +99,24 @@ async def extract_filter_img(image) -> Dict:
78
 
79
  def transcribe_menu_model(menu_text: List[AnyStr]) -> Dict:
80
 
81
- prompt_item = INSTRUCTION_PROMPT + " " + menu_text + """
82
-
83
-
84
- """
85
  input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids
86
 
87
  outputs = item_summarizer.generate(input_ids,
88
- max_new_tokens = 512
 
 
 
 
 
 
 
 
89
  )
90
 
91
- return item_tokenizer.decode(
92
- outputs[0],
93
- skip_special_tokens = True
94
- )
95
 
96
  def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
97
  return extrc_str
 
2
 
3
  from inference.preprocess_image import (
4
  image_to_np_arr,
5
+ process_extracted_text,
6
+ post_process_gen_outputs
7
  )
8
 
9
+ from inference.config import (
10
+ model_inf_inp_prompt,
11
+ header_pattern,
12
+ dots_pattern,
13
+ DEVICE,
14
+ model_name
15
+ )
16
  from typing import List, Tuple, Optional, AnyStr, Dict
17
+ from transformers import AutoTokenizer, AutoModelForCausalLM
18
  import easyocr
19
  import time
20
 
21
  use_gpu = True
22
+ if DEVICE.type == 'cpu': use_gpu = False
23
 
24
  @st.cache_resource
25
  def load_models(item_summarizer: AnyStr) -> Tuple:
26
+
27
+ """
28
+ Function to load the models required for the inference process. Cached to avoid loading the models, every time the function is called.
29
+
30
+ Parameters:
31
+ item_summarizer: str, required -> The LLM model name to be used for item summarization.
32
+
33
+ Returns:
34
+ Tuple -> Tuple containing the required models for the inference process.
35
+ """
36
+
37
+ # model to extract text from image
38
  text_extractor = easyocr.Reader(['en'],
39
  gpu = use_gpu
40
  )
41
+
42
+ # tokenizer and model to generate item summary
43
+ tokenizer = AutoTokenizer.from_pretrained(item_summarizer)
44
+ model = AutoModelForCausalLM.from_pretrained(item_summarizer)
45
 
46
  return (text_extractor, tokenizer, model)
47
 
48
+ text_extractor,item_tokenizer,item_summarizer = load_models(item_summarizer = model_name)
49
 
50
 
51
  # Define your extract_filter_img function
 
99
 
100
  def transcribe_menu_model(menu_text: List[AnyStr]) -> Dict:
101
 
102
+ prompt_item = model_inf_inp_prompt.format(menu_text)
 
 
 
103
  input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids
104
 
105
  outputs = item_summarizer.generate(input_ids,
106
+ max_new_tokens = 512,
107
+ num_beams = 4,
108
+ pad_token_id = item_tokenizer.pad_token_id,
109
+ eos_token_id = item_tokenizer.eos_token_id,
110
+ bos_token_id = item_tokenizer.bos_token_id
111
+ )
112
+
113
+ prediction = item_tokenizer.batch_decode(outputs,
114
+ skip_special_tokens=False
115
  )
116
 
117
+ postpro_output = post_process_gen_outputs( prediction, header_pattern, dots_pattern )[0]
118
+
119
+ return postpro_output
 
120
 
121
  def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
122
  return extrc_str
pages.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit import session_state as sst
3
+
4
+
5
+ from utils import navigate_to
6
+ from inference.config import DEBUG_MODE
7
+
8
+ from inference.translate import extract_filter_img, transcribe_menu_model,classify_menu_text
9
+ from inference.preprocess_image import preprocess_text
10
+
11
+ import os
12
+ import time
13
+ import pandas as pd
14
+ from PIL import Image
15
+ from typing import List
16
+ import json
17
+ from concurrent.futures import ThreadPoolExecutor, as_completed
18
+
19
+ # Setting workers to be 70% of all available virtual cpus in system
20
+ cpu_count = os.cpu_count()
21
+ pool = ThreadPoolExecutor(max_workers=int(cpu_count*0.7) )
22
+
23
+ # Function that handles logic of explaining menu items from manual input
24
+ async def manual_input_page():
25
+
26
+ """
27
+ Function that takes text input from user in input box of streamlit, user can add multiple text boxes and submit finally.
28
+
29
+ Parameters:
30
+ None
31
+
32
+ Returns:
33
+ List[str]: List of strings, containing item names of a menu in english.
34
+ """
35
+
36
+ st.write("This is the Manual Input Page.")
37
+ st.write("Once done, click on 'Explain My Menu' button to get explanations for each item ... ")
38
+
39
+ inp_texts = []
40
+ num_text_boxes = st.number_input("Number of text boxes", min_value=1, step=1)
41
+ for i in range(num_text_boxes):
42
+ text_box = st.text_input(f"Food item {i+1}")
43
+ if text_box:
44
+ inp_texts.append(text_box)
45
+
46
+ if len(inp_texts) > 0:
47
+
48
+ # Show user submit button only if they have entered some text and set text in session state
49
+ sst["user_entered_items"] = inp_texts
50
+ st.button("Explain My Menu",on_click=navigate_to,args=("Inference",))
51
+
52
+ else:
53
+ st.write("Please enter some items to proceed ...")
54
+
55
+
56
+ st.button("Go back Home", on_click=navigate_to, args=("Home",))
57
+
58
+
59
+ # Function that handles logic of explaining menu items from image uploads
60
+ async def image_input_page():
61
+ """
62
+ Function that contains content of main page i.e., image uploader and submit button to navigate to next page.
63
+ Upon submit , control goes to model inference 'page'.
64
+
65
+ Parameters:
66
+ None
67
+
68
+ Returns:
69
+ None
70
+ """
71
+
72
+ st.write("This is the Image Input Page.")
73
+
74
+ # Streamlit function to upload an image from any device
75
+ uploaded_file = st.file_uploader("Choose an image...",
76
+ type=["jpg", "jpeg", "png"])
77
+
78
+ # Remove preivous states' value of input image if it exists
79
+ sst.pop('input_image', None)
80
+
81
+ # Submit button
82
+ if uploaded_file is not None:
83
+ image = Image.open(uploaded_file)
84
+
85
+ # Only show if user wants to see
86
+ if st.checkbox('Show Uploaded Image'):
87
+ st.image(image,
88
+ caption='Uploaded Image',
89
+ use_column_width=True)
90
+
91
+ sst["input_image"] = image
92
+
93
+ # Show user submit button only if they have uploaded an image
94
+ st.button("Translate My Menu",
95
+ on_click = navigate_to,
96
+ args = ("Inference",))
97
+
98
+
99
+ # Warning message to user
100
+ st.info("""This application is for education purposes only. It uses AI, hence it's dietary
101
+ recommendations are not to be taken as medical advice, author doesn't bear responsibility
102
+ for incorrect dietary recommendations. Please proceed with caution.
103
+ """)
104
+
105
+ # if user wants to go back, make sure to reset the session state
106
+ st.button("Go back Home", on_click=navigate_to, args=("Home",))
107
+
108
+
109
+ # Function that handles model inference
110
+ async def model_inference_page():
111
+
112
+ """
113
+ Function that pre-processes input text from state variables, does concurrent inference
114
+ and toggles state between pages if needed.
115
+
116
+ Parameters:
117
+ None
118
+ Returns:
119
+ None
120
+
121
+ """
122
+
123
+ second_title = st.empty()
124
+ second_title.title(" Using ML to explain your menu items ... ")
125
+
126
+ # User can either upload an image or enter text manually, we check for both
127
+ if "input_image" in sst:
128
+ image = sst["input_image"]
129
+
130
+ msg1 = st.empty()
131
+ msg1.write("Pre-processing and extracting text out of your image ....")
132
+ # Call the extract_filter_img function
133
+ filtered_text = await extract_filter_img(image)
134
+ num_items_detected = len(filtered_text)
135
+
136
+
137
+ if "user_entered_items" in sst:
138
+ user_text = sst["user_entered_items"]
139
+ st.write("Pre-processing and filtering text from user input ....")
140
+
141
+ filtered_text = [preprocess_text(ut) for ut in user_text]
142
+
143
+ num_items_detected = len(filtered_text)
144
+
145
+
146
+ # irrespective of source of user entry , we check if we have any items to process
147
+ if num_items_detected == 0:
148
+ st.write("We couldn't detect any menu items ( indian for now ) from your image, please try a different image by going back.")
149
+
150
+ elif num_items_detected > 0:
151
+ st.write(f"Detected {num_items_detected} menu items from your input image ... ")
152
+
153
+ msg2 = st.empty()
154
+ msg2.write("All pre-processing done, transcribing your menu items now ....")
155
+ st_trans_llm = time.perf_counter()
156
+
157
+ await dist_llm_inference(filtered_text)
158
+
159
+ msg3 = st.empty()
160
+ msg3.write("Done transcribing ... ")
161
+ en_trans_llm = time.perf_counter()
162
+
163
+ msg2.empty(); msg3.empty()
164
+ st.success("Image processed successfully! " )
165
+
166
+ # Some basic stats for debug mode
167
+ if DEBUG_MODE:
168
+ llm_time_sec = en_trans_llm - st_trans_llm
169
+ st.write("Time took to summarize by LLM {}".format(llm_time_sec))
170
+
171
+
172
+ # If user clicked in "translate_another" button reset all session state variables and go back to home
173
+ st.button("Go back Home", on_click=navigate_to, args=("Home",))
174
+
175
+
176
+ # Function that performs LLM inference on a single item
177
+ async def dist_llm_inference(inp_texts: List[str]) -> None:
178
+
179
+ """
180
+ Function that performs concurrent LLM inference using threadpool. It displays
181
+ results of those threads that are done with execution, as a dynamic row to streamlit table, rather than
182
+ waiting for all threads to be done.
183
+
184
+ Parameters:
185
+ inp_texts: List[str], required -> List of strings, containing item names of a menu in english.
186
+
187
+ Returns:
188
+ None
189
+ """
190
+
191
+ df = pd.DataFrame([('ITEM NAME', 'EXPLANATION')]
192
+ )
193
+
194
+ sl_table = st.table(df)
195
+ tp_futures = { pool.submit(transcribe_menu_model, mi): mi for mi in inp_texts }
196
+
197
+ for tpftr in as_completed(tp_futures):
198
+
199
+ item = tp_futures[tpftr]
200
+
201
+ try:
202
+ exp = tpftr.result()
203
+
204
+
205
+ sl_table.add_rows([(item,
206
+ str(exp ))
207
+ ]
208
+ )
209
+
210
+ except Exception as e:
211
+ print("Could not add a new row dynamically, because of this error:", e)
212
+
213
+ return
214
+
utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from streamlit import session_state as sst
3
+ def navigate_to(page: str) -> None:
4
+ """
5
+ Function to set the current page in the state of streamlit. A helper for
6
+ simulating navigation in streamlit.
7
+
8
+ Parameters:
9
+ page: str, required.
10
+
11
+ Returns:
12
+ None
13
+ """
14
+
15
+ sst["page"] = page