|
from typing import Dict, List, Any |
|
from PIL import Image |
|
from tfing import TFIng |
|
from tfport import TFPort, get_look_ahead_mask, get_padding_mask |
|
|
|
import os |
|
import json |
|
import tensorflow as tf |
|
import numpy as np |
|
|
|
|
|
class PreTrainedPipeline:
    """Predict ingredients, portions and nutrition values from a food image.

    Two Keras models are used: ``TFIng`` autoregressively decodes a sequence
    of ingredient tokens from the image, and ``TFPort`` autoregressively
    decodes one portion value per decoded ingredient.

    Token conventions used throughout: ``0`` fills unused positions,
    ``vocab_size - 2`` (``_start_token``) is the first decoder input and
    ``vocab_size - 1`` (``_end_token``) terminates a sequence. Every other id
    is a key (as a string) into ``ingredients_metadata.json``.
    """

    def __init__(self, path=""):
        """Load ingredient metadata and the weights of both models.

        Args:
            path: Directory containing ``ingredients_metadata.json``,
                ``tfing.h5`` and ``tfport.h5``. The default ``""`` resolves
                these files relative to the current working directory.
        """
        # Model hyper-parameters; the architecture built here must match the
        # saved weight files.
        embed_dim = 256
        num_layers = 3
        seq_length = 20
        hidden_dim = 1024
        num_heads = 8

        self.crop_size = (224, 224)
        self.img_size = 256
        self.nutr_names = ('energy', 'fat', 'protein', 'carbs')

        with open(os.path.join(path, "ingredients_metadata.json"), encoding='UTF-8') as f:
            self.ingredients = json.load(f)
        self.ing_names = {ing['name']: int(ing_id) for ing_id, ing in self.ingredients.items()}
        # Three extra ids are reserved for the padding, start and end tokens.
        self.vocab_size = len(self.ingredients) + 3
        self.seq_length = seq_length

        dummy_image = tf.zeros((1, *self.crop_size, 3))
        dummy_seq = tf.zeros((1, seq_length))

        self.tfing = TFIng(
            self.crop_size,
            embed_dim,
            num_layers,
            seq_length,
            hidden_dim,
            num_heads,
            self.vocab_size,
        )
        self.tfing.compile()
        # Call once on a dummy batch so the model is built before load_weights.
        self.tfing((dummy_image, dummy_seq))
        self.tfing.load_weights(os.path.join(path, 'tfing.h5'))

        self.tfport = TFPort(
            self.crop_size,
            embed_dim,
            num_layers,
            num_layers,
            seq_length,
            seq_length,
            hidden_dim,
            num_heads,
            self.vocab_size,
        )
        self.tfport.compile()
        self.tfport((dummy_image, dummy_seq, dummy_seq))
        self.tfport.load_weights(os.path.join(path, 'tfport.h5'))

    @property
    def _start_token(self):
        """Id fed as the first ingredient-decoder input."""
        return self.vocab_size - 2

    @property
    def _end_token(self):
        """Id that terminates an ingredient sequence."""
        return self.vocab_size - 1

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Run the full pipeline on a PIL image.

        The image is converted to RGB when needed. It is then resized so that
        its shorter side equals ``img_size`` (keeping the aspect ratio),
        preprocessed with ``inception_v3.preprocess_input`` and center-cropped
        to ``crop_size``.

        Returns:
            One ``{"label": name, "score": portion}`` dict per predicted
            ingredient. ``portion`` is the ingredient's normalized share
            (all portions sum to 100).
        """
        if isinstance(inputs, Image.Image) and inputs.mode != "RGB":
            # Non-RGB modes (e.g. 'L', 'P', 'RGBA') would otherwise produce the
            # wrong number of channels for the 3-channel model input.
            inputs = inputs.convert("RGB")
        image = tf.keras.preprocessing.image.img_to_array(inputs)
        height, width = image.shape[:2]
        if width > height:
            new_size = (self.img_size, int(float(self.img_size * width) / float(height)))
        else:
            new_size = (int(float(self.img_size * height) / float(width)), self.img_size)
        image = tf.image.resize(image, new_size)
        image = tf.keras.applications.inception_v3.preprocess_input(image)
        image = tf.keras.layers.CenterCrop(*self.crop_size)(image)

        prediction = self.predict(image)
        return [
            {"label": label, "score": score}
            for label, score in zip(prediction['ingredients'], prediction['portions'])
        ]

    def encode_image(self, image):
        """Encode a batched image with TFIng's encoder and conv layers.

        The conv output is indexed up to axis 3, so it is a rank-4 feature map.
        All axes between the first and the last are flattened, which gives a
        ``(batch, positions, channels)`` sequence.
        """
        encoder_out = self.tfing.encoder(image)
        encoder_out = self.tfing.conv(encoder_out)
        return tf.reshape(
            encoder_out,
            (tf.shape(encoder_out)[0], -1, tf.shape(encoder_out)[3]),
        )

    def encode_ingredients(self, ingredients, padding_mask):
        """Encode an ingredient token sequence with TFPort's ingredient encoder."""
        return self.tfport.ingredient_encoder(ingredients, padding_mask)

    def decode_ingredients(self, encoded_img, decoder_in):
        """Return TFIng logits for every position of ``decoder_in``.

        The model's replacement mask is added to the raw logits. This
        presumably suppresses tokens that may not be emitted next; confirm
        the exact semantics in ``TFIng.get_replacement_mask``.
        """
        decoder_outputs = self.tfing.decoder(decoder_in, encoded_img)
        output = self.tfing.linear(decoder_outputs)
        return output + self.tfing.get_replacement_mask(decoder_in)

    def decode_portions(self, encoded_img, encoded_ingr, decoder_in, padding_mask):
        """Run TFPort's portion decoder and return the squeezed predictions.

        The decoder attends over the image features concatenated with the
        encoded ingredients. Image positions are never masked: a ones mask is
        prepended to ``padding_mask`` along axis 2. This assumes
        ``padding_mask`` is ``(batch, 1, seq)`` and int32-compatible; TODO
        confirm against ``get_padding_mask``.
        """
        encoder_outputs = tf.concat([encoded_img, encoded_ingr], axis=1)
        img_mask = tf.ones((tf.shape(encoded_img)[0], 1, tf.shape(encoded_img)[1]), dtype=tf.int32)
        padding_mask = tf.concat([img_mask, padding_mask], axis=2)
        look_ahead_mask = get_look_ahead_mask(decoder_in)

        x = self.tfport.portion_embedding(decoder_in)
        for layer in self.tfport.decoder_layers:
            x = layer(x, encoder_outputs, look_ahead_mask, padding_mask=padding_mask)
        x = self.tfport.linear(x)
        return tf.squeeze(x)

    def predict_ingredients(self, encoded_img, known_ing=None):
        """Greedily decode an ingredient token sequence of length ``seq_length``.

        Args:
            encoded_img: Output of ``encode_image``.
            known_ing: Optional sequence of ingredient ids used as a fixed
                prefix. Only the first ``seq_length - 1`` ids are kept, so at
                least one slot remains for decoding and termination.

        Returns:
            1-D int array of length ``seq_length``. It ends with
            ``_end_token`` (forced on the last slot if the model never emits
            it) and holds zeros after the end token.
        """
        predicted = np.zeros((1, self.seq_length + 1), dtype=int)
        predicted[0, 0] = self._start_token
        # Truncate oversized prefixes instead of failing the numpy assignment.
        known = list(known_ing)[:self.seq_length - 1] if known_ing is not None else []
        if known:
            predicted[0, 1:len(known) + 1] = known

        for i in range(len(known), self.seq_length):
            # Re-decode the whole prefix each step; only position i is used.
            decoded = self.decode_ingredients(encoded_img, predicted[:, :-1])
            next_token = int(np.argmax(decoded[0, i]))
            predicted[0, i + 1] = next_token
            if next_token == self._end_token:
                break
        else:
            # Max length reached without an end token: force one on the last slot.
            predicted[0, self.seq_length] = self._end_token
        return predicted[0, 1:]

    def predict_portions(self, encoded_image, ingredients):
        """Autoregressively decode one raw portion per ingredient position.

        Args:
            encoded_image: Output of ``encode_image``.
            ingredients: ``(1, seq_length)`` array of ingredient ids.

        Returns:
            1-D float array of length ``seq_length``. Positions at and after
            the first ``_end_token`` are left at 0.
        """
        predicted = np.zeros((1, self.seq_length + 1), dtype=float)
        predicted[0, 0] = -1  # start value fed to the portion decoder
        padding_mask = get_padding_mask(ingredients)
        encoded_ingr = self.encode_ingredients(ingredients, padding_mask)

        for i in range(self.seq_length):
            if ingredients[0, i] == self._end_token:
                break
            step = self.decode_portions(encoded_image, encoded_ingr, predicted[:, :-1], padding_mask)
            predicted[0, i + 1] = float(step[i])
        return predicted[0, 1:]

    def process_ingredients(self, ingredients):
        """Map newline-separated ingredient names to ids.

        Unknown names are silently dropped. A line consisting only of ``.``
        stops parsing and marks the list as complete.

        Returns:
            ``(ids, complete)``. ``complete`` is True iff a ``.`` line was seen.
        """
        processed = []
        for ingredient in ingredients.split('\n'):
            stripped = ingredient.strip()
            if stripped == '.':
                return processed, True
            if stripped in self.ing_names:
                processed.append(self.ing_names[stripped])
        return processed, False

    def _summarize(self, tokens, portions):
        """Build the public result dict from ingredient ids and raw portions.

        ``portions[i]`` is the raw portion of ``tokens[i]``; extra trailing
        entries are ignored. Portions are rescaled to sum to 100, and the same
        rescaled values weight the nutrition totals, so both fields agree.
        Empty input, or raw portions summing to 0, yields zero or empty
        portions instead of a ZeroDivisionError.
        """
        raw = [float(p) for p in portions[:len(tokens)]]
        total = sum(raw)
        scale = 100 / total if total else 0.0
        normalized = [p * scale for p in raw]
        info = [self.ingredients[str(tok)] for tok in tokens]
        return {
            'ingredients': [item['name'] for item in info],
            'portions': normalized,
            'nutrition': {
                name: sum(item[name] * p / 100 for item, p in zip(info, normalized))
                for name in self.nutr_names
            },
        }

    def predict(self, image, known_ing=None):
        """Predict ingredients, portions and nutrition for a preprocessed image.

        Args:
            image: Preprocessed, cropped image without a batch axis.
            known_ing: Optional newline-separated ingredient names. Recognized
                names seed the ingredient decoder. A ``.`` line marks the list
                as complete, and ingredient prediction is then skipped.

        Returns:
            Dict with ``'ingredients'`` (names), ``'portions'`` (normalized to
            sum to 100; empty when no ingredient was found) and
            ``'nutrition'`` (one value per name in ``nutr_names``, weighted by
            the same normalized portions).
        """
        encoded_image = self.encode_image(image[tf.newaxis, :])

        known_tokens, skip_ing = (
            self.process_ingredients(known_ing) if known_ing else (None, False)
        )
        if skip_ing:
            # The caller supplied the complete list: truncate, terminate and pad it.
            ingredients = known_tokens[:self.seq_length - 1] + [self._end_token]
            ingredients = np.pad(ingredients, (0, self.seq_length - len(ingredients)))
        else:
            ingredients = self.predict_ingredients(encoded_image, known_ing=known_tokens)

        tokens = [int(tok) for tok in ingredients if tok != 0 and tok != self._end_token]

        # A single ingredient is trivially 100%. With none, the result is empty.
        if len(tokens) > 1:
            portions = self.predict_portions(encoded_image, ingredients[np.newaxis, :])
        else:
            portions = [100] * len(tokens)

        return self._summarize(tokens, portions)