# Author: SumanthKarnati
# Commit: bc81463 ("Update app.py")
import gradio as gr
import os
import nltk
from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTImageProcessor, pipeline
import torch
from PIL import Image
from nltk.corpus import stopwords
from io import BytesIO
# --- One-time start-up: stop-word corpus, models, and caption generation settings ---
nltk.download('stopwords')

# Image -> ingredients captioning model; process_image parses its output as a
# '-'-separated ingredient list.
model = VisionEncoderDecoderModel.from_pretrained("SumanthKarnati/Image2Ingredients")
model.eval()

# Pre-processing (pixels) and decoding (token ids -> text) both come from the
# base captioning checkpoint.
_BASE_CAPTIONER = 'nlpconnect/vit-gpt2-image-captioning'
feature_extractor = ViTImageProcessor.from_pretrained(_BASE_CAPTIONER)
tokenizer = AutoTokenizer.from_pretrained(_BASE_CAPTIONER)

# Language model that turns the ingredient list into free-text advice.
generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')

# Only the captioning model is moved explicitly. NOTE(review): no `device` is
# passed to the text-generation pipeline, so it uses the pipeline's default
# placement; confirm that is intended.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

# Beam-search settings forwarded to model.generate in predict_step.
max_length, num_beams = 16, 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

stop_words = set(stopwords.words('english'))
def remove_stop_words(word_list, stopword_set=None):
    """Return the words of *word_list* that are not stop words, in order.

    Matching is case-insensitive: both the words and the stop-word set are
    lower-cased before comparison. Without this, capitalised forms such as
    "The" or "And" would slip past NLTK's lower-case English list.

    Args:
        word_list: iterable of strings to filter.
        stopword_set: optional collection of stop words to use instead of
            the module-level ``stop_words`` set (NLTK English).

    Returns:
        A new list with the surviving words, original order and casing kept.
    """
    if stopword_set is None:
        stopword_set = stop_words
    lowered = {word.lower() for word in stopword_set}
    return [word for word in word_list if word.lower() not in lowered]
def _load_rgb_image(image_file):
    """Return *image_file* as an RGB PIL image.

    Accepts:
      * a ``PIL.Image.Image``;
      * an array exposing ``__array_interface__`` (e.g. the numpy array that
        Gradio's ``"image"`` input delivers by default);
      * a filesystem path (``str`` or ``os.PathLike``);
      * an upload/temp-file object whose ``.name`` attribute is a path.
    """
    # PIL images also expose __array_interface__, so test for them first.
    if isinstance(image_file, Image.Image):
        image = image_file
    elif hasattr(image_file, "__array_interface__"):
        image = Image.fromarray(image_file)
    else:
        path = image_file if isinstance(image_file, (str, os.PathLike)) else image_file.name
        # convert() forces a full load, so the file handle can be closed at once.
        with Image.open(path) as opened:
            return opened.convert(mode="RGB")
    if image.mode != "RGB":
        image = image.convert(mode="RGB")
    return image


def predict_step(image_files, model, feature_extractor, tokenizer, device, gen_kwargs):
    """Caption a batch of images with a vision-encoder-decoder model.

    Args:
        image_files: iterable of images. ``None`` entries are skipped. Other
            entries may be PIL images, numpy arrays, paths, or objects whose
            ``.name`` is a path.
        model: model whose ``generate`` takes the pixel tensor.
        feature_extractor: image processor returning ``pixel_values``.
        tokenizer: tokenizer used to decode the generated ids.
        device: torch device the processed inputs are moved to.
        gen_kwargs: keyword arguments forwarded to ``model.generate``.

    Returns:
        A list of stripped caption strings, one per non-``None`` image, or
        ``None`` when no image was supplied.
    """
    images = [_load_rgb_image(item) for item in image_files if item is not None]
    if not images:
        return None
    inputs = feature_extractor(images=images, return_tensors="pt").to(device)
    output_ids = model.generate(inputs["pixel_values"], **gen_kwargs)
    captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [caption.strip() for caption in captions]
# Upper bound on tokens the text generator may add after the prompt. A bare
# min_length leaves the total max_length at the generation default (20 in
# transformers' GenerationConfig unless the checkpoint overrides it). That
# total also counts the prompt, which here is longer, so the advice would be
# cut to almost nothing.
_SUGGESTION_MAX_NEW_TOKENS = 256


def _extract_ingredients(caption):
    """Split a '-'-separated caption into a clean, de-duplicated ingredient list."""
    pieces = [piece.strip() for piece in caption.split('-')]
    # Drop empty/whitespace-only fragments and any piece containing a digit
    # (quantities, numbering).
    pieces = [p for p in pieces if p and not any(ch.isdigit() for ch in p)]
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return remove_stop_words(list(dict.fromkeys(pieces)))


def process_image(image):
    """Gradio callback: image -> (ingredient list, generated nutritional advice).

    Args:
        image: the value from the Gradio image input; ``None`` when the user
            submits without uploading.

    Returns:
        A 2-tuple ``(ingredients, suggestions)``. *ingredients* is a list of
        strings and *suggestions* is the text produced by the generator. When
        no image or no usable ingredient is available, a short explanatory
        message is returned instead of calling the language model.
    """
    preds = predict_step([image], model, feature_extractor, tokenizer, device, gen_kwargs)
    if not preds:
        # predict_step returns None when the only image is None.
        return [], "Please upload an image."
    ingredients = _extract_ingredients(preds[0])
    if not ingredients:
        return ingredients, "No ingredients could be identified in this image."
    preds_str = ', '.join(ingredients)
    prompt = (
        "You are a knowledgeable assistant that provides nutritional advice "
        "based on a list of ingredients. "
        f"The identified ingredients are: {preds_str}. "
        "Note that some ingredients may not make sense, so use the ones that do. "
        "Can you provide a nutritional analysis and suggestions for improvement?"
    )
    # min_length counts prompt tokens too; max_new_tokens bounds only the
    # continuation. return_full_text=False makes the pipeline return just the
    # continuation.
    outputs = generator(
        prompt,
        do_sample=True,
        min_length=200,
        max_new_tokens=_SUGGESTION_MAX_NEW_TOKENS,
        return_full_text=False,
    )
    return ingredients, outputs[0]['generated_text']
# Web UI: one image input and two text outputs, matching the
# (ingredients, suggestions) pair returned by process_image.
iface = gr.Interface(process_image, "image", ["text", "text"])
iface.launch()