############################################################################################################################# | |
# Filename : app.py | |
# Description: A Streamlit application to generate recipes, given an image of a food and an image of ingredients. | |
# Author : Georgios Ioannou | |
# | |
# Copyright © 2024 by Georgios Ioannou
############################################################################################################################# | |
# Import libraries.
import html  # Escape chat messages before rendering them as HTML.
import io  # Wrap uploaded image bytes in an in-memory file object.
import os  # Load environment variable(s).
import tempfile  # Save uploaded images under private temporary paths.

import openai  # gpt-3.5-turbo model inference.
import requests  # Send HTTP GET request to Hugging Face models for inference.
import streamlit as st  # Build the GUI of the application.
import torch  # Load Salesforce/blip model(s) on GPU.
from langchain.chat_models import ChatOpenAI  # Access to OpenAI gpt-3.5-turbo model.
from langchain.chains import LLMChain  # Chain to run queries against LLMs.
# A prompt template. It accepts a set of parameters from the user that can be used to generate a prompt for a language model.
from langchain.prompts import PromptTemplate
from PIL import Image  # Open and identify a given image file.
from transformers import BlipProcessor, BlipForQuestionAnswering  # VQA model inference.
############################################################################################################################# | |
# Load environment variable(s).
# Each value is None when the corresponding variable is not set.
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
openai.api_key = os.environ.get("OPENAI_API_KEY")
############################################################################################################################# | |
# Function to apply local CSS.
def local_css(file_name):
    """Inject the stylesheet stored at *file_name* into the page.

    The file contents are wrapped in a ``<style>`` tag and rendered with
    ``unsafe_allow_html=True`` so the raw HTML reaches the browser.

    Args:
        file_name: Path to a CSS file, decoded as UTF-8.
    """
    # Decode explicitly as UTF-8. Otherwise non-ASCII characters in the CSS
    # would depend on the platform's default locale encoding.
    with open(file_name, encoding="utf-8") as css_file:
        css = css_file.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
############################################################################################################################# | |
# Load the Visual Question Answering (VQA) model directly.
# Using transformers.
@st.cache_resource
def load_model():
    """Load the Salesforce/blip-vqa-base processor and question-answering model.

    This function is called at module level, and Streamlit re-executes the whole
    script on every interaction. Without caching, the checkpoint would be
    reloaded on every rerun. ``st.cache_resource`` keeps a single shared copy
    instead.

    Backup model: to use the larger checkpoint, load both the processor and the
    model from "Salesforce/blip-vqa-capfilt-large" instead.

    Returns:
        A ``(processor, model)`` tuple.
    """
    blip_processor_base = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    blip_model_base = BlipForQuestionAnswering.from_pretrained(
        "Salesforce/blip-vqa-base"
    )
    return blip_processor_base, blip_model_base
############################################################################################################################# | |
# General function for any Salesforce/blip model(s).
# VQA model.
def generate_answer_blip(processor, model, image, question, max_length=50):
    """Answer *question* about *image* with a BLIP VQA processor/model pair.

    Args:
        processor: BLIP processor that encodes the image and question and
            decodes the generated ids.
        model: BLIP question-answering model whose ``generate`` produces the
            answer ids.
        image: Image handed to the processor (``main`` passes an RGB PIL image).
        question: Natural-language question about the image.
        max_length: Maximum generated sequence length (defaults to 50, as before).

    Returns:
        The list returned by ``processor.batch_decode``: one decoded string per
        generated sequence. Callers take element ``[0]``.
    """
    # The model may live on CUDA (the script calls .to(device) after loading it),
    # while the processor builds CPU tensors. Move the inputs onto the model's
    # device to avoid a device-mismatch error.
    inputs = processor(images=image, text=question, return_tensors="pt").to(
        model.device
    )
    generated_ids = model.generate(**inputs, max_length=max_length)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
############################################################################################################################# | |
# Generate answer from the Salesforce/blip model(s).
# VQA model.
def generate_answer(image, question):
    """Answer *question* about *image* using the module-level base BLIP VQA model.

    Delegates to ``generate_answer_blip`` with ``blip_processor_base`` and
    ``blip_model_base``, and returns its list unchanged (callers take ``[0]``).
    If the large backup checkpoint is ever loaded, pass its processor/model
    pair here instead.
    """
    answers = generate_answer_blip(
        blip_processor_base,
        blip_model_base,
        image,
        question,
    )
    return answers
############################################################################################################################# | |
# Detect ingredients on an image.
# Object detection model.
def generate_ingredients(image, timeout=60):
    """Detect objects in an image file with facebook/detr-resnet-50 (HF Inference API).

    Args:
        image: Path of the image file whose raw bytes are uploaded.
        timeout: Seconds to wait for the HTTP response before ``requests``
            gives up.

    Returns:
        The decoded JSON body of a successful response. ``main`` iterates over
        it and reads each element's ``"label"``.

    Raises:
        RuntimeError: If the API answers with a non-2xx status. The start of the
            response body is included so the cause is visible to the user.
        requests.RequestException: On connection errors or timeout.
    """
    api_url = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
    headers = {"Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}"}
    with open(image, "rb") as img:
        data = img.read()
    # Without a timeout, a stalled request would block this Streamlit run forever.
    response = requests.post(url=api_url, data=data, headers=headers, timeout=timeout)
    if not response.ok:
        # Fail here rather than hand the caller an error payload that it would
        # try to iterate as a list of detections.
        raise RuntimeError(
            f"Object detection request failed ({response.status_code}): "
            f"{response.text[:500]}"
        )
    return response.json()
############################################################################################################################# | |
# Return the recipe generated by the model for the food and ingredients detected by the previous models.
# Using Langchain.
def generate_recipe(food, ingredients, chef):
    """Ask gpt-3.5-turbo, through a LangChain ``LLMChain``, for a step-by-step recipe.

    The prompt tells the model to write in the voice of *chef*, to use
    *ingredients*, and to produce a detailed recipe for *food*. Editing the
    template repurposes the chain (e.g. a scenario or song lyrics).

    Args:
        food: Name of the dish to cook.
        ingredients: Ingredients to use; interpolated into the prompt as given.
        chef: Persona whose voice the recipe should imitate.

    Returns:
        The text returned by ``LLMChain.predict``.
    """
    template = """
You are a chef.
You must sound like {chef}.
You must make use of these ingredients: {ingredients}.
Generate a detailed recipe step by step based on the above constraints for this food: {food}.
    """
    # Raising the temperature makes the model more creative but slower to answer.
    chat_model = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    recipe_prompt = PromptTemplate(
        input_variables=["food", "ingredients", "chef"], template=template
    )
    chain = LLMChain(
        prompt=recipe_prompt,
        llm=chat_model,
        verbose=True,  # Print intermediate values to the console.
    )
    # predict() formats the prompt with these keyword arguments and queries the LLM.
    return chain.predict(chef=chef, food=food, ingredients=ingredients)
############################################################################################################################# | |
# Return the speech generated by the model for the recipe.
# Using inference api.
def generate_speech(response, timeout=60, output_path="audio.flac"):
    """Synthesize speech for the text *response* with facebook/mms-tts-eng (HF Inference API).

    The bytes returned by the API are written to *output_path* ("audio.flac" by
    default), which the UI later plays with ``st.audio``.

    Backup model: "espnet/kan-bayashi_ljspeech_vits"
    (https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_vits).

    Args:
        response: Text to convert to speech.
        timeout: Seconds to wait for the HTTP response before ``requests``
            gives up.
        output_path: File the returned audio bytes are written to.

    Raises:
        RuntimeError: If the API answers with a non-2xx status. In that case
            *output_path* is left untouched.
        requests.RequestException: On connection errors or timeout.
    """
    api_url = "https://api-inference.huggingface.co/models/facebook/mms-tts-eng"
    headers = {"Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}"}
    http_response = requests.post(
        url=api_url, headers=headers, json={"inputs": response}, timeout=timeout
    )
    if not http_response.ok:
        # Previously the error body was saved as the audio file and then failed
        # to play. Report the failure instead.
        raise RuntimeError(
            f"Text-to-speech request failed ({http_response.status_code}): "
            f"{http_response.text[:500]}"
        )
    with open(output_path, "wb") as file:
        file.write(http_response.content)
############################################################################################################################# | |
# Conversation with OpenAI gpt-3.5-turbo model.
def get_completion_from_messages(messages, model="gpt-3.5-turbo", temperature=0):
    """Send a chat transcript to OpenAI and return the reply text.

    Uses the ``openai.ChatCompletion.create`` interface.

    Args:
        messages: Chat messages, each a dict with "role" and "content" keys.
        model: OpenAI chat model name.
        temperature: Degree of randomness of the model's output.

    Returns:
        The ``"content"`` of the first choice's message.
    """
    request = dict(model=model, messages=messages, temperature=temperature)
    completion = openai.ChatCompletion.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message["content"]
############################################################################################################################# | |
# Page title and favicon.
# Keep this ahead of any other st.* rendering call; older Streamlit versions
# require set_page_config to be the first Streamlit command. The favicon was a
# mis-encoded string that is not a valid emoji; it is restored to the
# fork-and-knife emoji.
st.set_page_config(page_title="ChefBot | Recipe Generator/Assistant", page_icon="🍴")
############################################################################################################################# | |
# Choose the inference device, then load the Salesforce/blip model directly.
# (Apple-silicon "mps" is another option, via torch.backends.mps.is_available().)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
blip_processor_base, blip_model_base = load_model()
blip_model_base.to(device)
############################################################################################################################# | |
# Define the chefs for the dropdown menu: the personas the recipe and chat
# prompts ask the model to "sound like".
chefs = ["Gordon Ramsay", "Donald Trump", "Cardi B"]
############################################################################################################################# | |
# Helpers used by main() to build the Streamlit web application.
def _render_heading(text):
    """Render *text* as a centered, monospace ``<h3>`` section heading."""
    heading = (
        '<h3 align="center" style="font-family: monospace; font-size: 1.5rem; '
        f'margin-top: 1rem">{text}</h3>'
    )
    st.markdown(heading, unsafe_allow_html=True)


def _render_header():
    """Apply the page CSS and render the title, subtitle and centered logo."""
    local_css("styles/style.css")
    title = (
        '<h1 align="center" style="font-family: monospace; font-size: 2.1rem; '
        'margin-top: -4rem">ChefBot - Recipe Generator/Assistant</h1>'
    )
    st.markdown(title, unsafe_allow_html=True)
    subtitle = (
        '<h2 align="center" style="font-family: monospace; font-size: 1.5rem; '
        'margin-top: -2rem">CUNY Tech Prep Tutorial 2</h2>'
    )
    st.markdown(subtitle, unsafe_allow_html=True)
    # The middle of three equal columns centers the logo.
    _, center_column, _ = st.columns(3)
    with center_column:
        st.image(image="./ctp.png")


def _save_upload_to_temp(uploaded_file):
    """Write an uploaded file to a private temporary path and return that path.

    This avoids saving into the working directory under a client-supplied name.
    Such a name could overwrite application files (e.g. "app.py") or collide
    between sessions. The caller is responsible for deleting the file.
    """
    suffix = os.path.splitext(os.path.basename(uploaded_file.name))[1]
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getvalue())
    return tmp.name


def _detect_food(uploaded_file):
    """Show the uploaded food image and identify the food with the BLIP VQA model.

    Returns:
        The food name predicted by the model, or ``None`` (after showing a
        warning) when the model does not answer "yes" to the food check.
    """
    st.image(uploaded_file, caption="Uploaded Image.", use_column_width=True)
    # Decode straight from memory; nothing needs to be written to disk.
    raw_image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
    with st.spinner(text="Detecting food..."):
        answer = generate_answer(raw_image, "Is there a food in the picture?")[0]
        if answer != "yes":
            # Previously execution continued and `food` was never assigned,
            # so generating a recipe later raised a NameError.
            st.warning(
                f"No food detected in the picture (answer: {answer}). "
                "Please upload an image of a food."
            )
            return None
        st.success(f"Food detected? {answer}", icon="✅")
        food = generate_answer(
            image=raw_image, question="What is the food in the picture?"
        )[0]
        st.success(f"Food detected: {food}", icon="✅")
    return food


def _detect_ingredients(uploaded_file):
    """Show the ingredients image, run object detection on it and list the labels.

    Returns:
        The detected labels in API order (duplicates kept), or ``None`` after
        showing an error when the response is not a list of detections.
    """
    st.image(uploaded_file, caption="Uploaded Image.", use_column_width=True)
    # generate_ingredients expects a file path, so stage the upload in a private temp file.
    image_path = _save_upload_to_temp(uploaded_file)
    try:
        with st.spinner(text="Detecting Ingredients..."):
            detections = generate_ingredients(image=image_path)
    finally:
        os.remove(image_path)
    if not isinstance(detections, list):
        # Anything other than a list is treated as an error payload
        # (presumably {"error": ...} from the Inference API; TODO confirm).
        st.error(f"Ingredient detection failed: {detections}")
        return None
    ingredients = [detection["label"] for detection in detections]
    st.success("Ingredients:", icon="🍎")
    for number, label in enumerate(ingredients, start=1):
        st.write(number, label)
    return ingredients


def _render_recipe_section(food, ingredients):
    """Render the chef picker and the "Generate Recipe" flow; return the selected chef."""
    _render_heading("ChefBot")
    chef = st.selectbox(
        label="Select your chef:",
        options=chefs,
        help="Select your chef.",
    )
    _, middle_column, _ = st.columns(3)
    with middle_column:
        generate_clicked = st.button("Generate Recipe")
    if generate_clicked:
        # Langchain + OpenAI gpt-3.5-turbo model inference.
        with st.spinner(text="Generating Recipe..."):
            recipe = generate_recipe(food=food, ingredients=ingredients, chef=chef)
        # Store the recipe in session state so the chat survives future reruns.
        st.session_state["recipe"] = recipe
        # Text to speech model inference.
        with st.spinner(text="Generating Audio..."):
            generate_speech(response=recipe)
        with st.expander(label="Recipe"):
            st.write(recipe)
            st.audio("audio.flac")
    return chef


def _chat_line_html(label, content):
    """Return the HTML paragraph for one chat message, with *content* escaped."""
    # Escape user and model text so it cannot inject markup into the page.
    # Convert newlines to <br> so a blank line in the text does not end the raw
    # HTML block early when Markdown renders it.
    body = html.escape(content).replace("\n", "<br>")
    return (
        '<p align="left" style="font-family: monospace; font-size: 1rem;">'
        f'<b style="color:#dadada">{label}</b> {body}</p>'
    )


def _render_chat(chef):
    """Render the follow-up conversation with ChefBot about the stored recipe."""
    system_message = {
        "role": "system",
        "content": f"""
You are a ChefBot, an automated service to guide users on how to cook step by step.
You must sound like {chef}.
You must first greet the user.
You must help the user step by step with this recipe: {st.session_state['recipe']}.
After you have given all of the steps of the recipe,
you must thank the user and ask for user feedback both on the recipe and on your personality.
Do NOT repeat the steps of any recipe during the conversation with the user.""",
    }
    user_input = st.text_input(
        label="User Input:",
        key="user_input",
        help="Follow up with the chef for any questions on the recipe.",
        placeholder="Clarify step 1.",
    )
    chat_column, _, _, _, reset_column = st.columns(5)
    with chat_column:
        chat_clicked = st.button("Chat")
    with reset_column:
        reset_clicked = st.button("Reset Chat")
    # Chat history as ("User:" | "Assistant:", text) pairs, kept across reruns.
    if "panels" not in st.session_state or reset_clicked:
        st.session_state.panels = []
    if reset_clicked:
        # Mark the pending text as handled so the next rerun does not re-send it.
        st.session_state["last_submitted_input"] = user_input
    elif user_input and (
        chat_clicked or user_input != st.session_state.get("last_submitted_input")
    ):
        # The text input keeps its value across reruns. Only send when the text
        # changed or "Chat" was clicked, not on every unrelated interaction.
        st.session_state["last_submitted_input"] = user_input
        # Rebuild the full transcript so the model remembers earlier turns.
        context = [system_message]
        for role, content in st.session_state.panels:
            context.append(
                {"role": "user" if role == "User:" else "assistant", "content": content}
            )
        context.append({"role": "user", "content": user_input})
        # OpenAI gpt-3.5-turbo model inference.
        with st.spinner(text="Generating Response..."):
            response = get_completion_from_messages(context)
        # Record the turn before text to speech so a TTS failure keeps the reply.
        st.session_state.panels.append(("User:", user_input))
        st.session_state.panels.append(("Assistant:", response))
        with st.spinner(text="Generating Audio..."):
            generate_speech(response=response)
    # Newest messages first.
    with st.expander("Conversation History", expanded=True):
        audio_attached = False
        for role, content in reversed(st.session_state.panels):
            if role == "User:":
                st.markdown(_chat_line_html(f"😀{role}", content), unsafe_allow_html=True)
            else:
                if not audio_attached:
                    # audio.flac only holds the latest synthesis, so attach it
                    # to the newest reply only.
                    st.audio("audio.flac")
                    audio_attached = True
                st.markdown(
                    _chat_line_html(f"👨‍🍳{chef}:", content), unsafe_allow_html=True
                )


def _render_app():
    """Render the page top to bottom, stopping at the first step that cannot continue."""
    _render_header()
    _render_heading("Food")
    uploaded_file_food = st.file_uploader(
        label="Choose an image:",
        key="food",
        help="An image of the food that you want a recipe for.",
    )
    if uploaded_file_food is None:
        return
    food = _detect_food(uploaded_file_food)
    if food is None:
        return
    _render_heading("Ingredients")
    uploaded_file_ingredients = st.file_uploader(
        label="Choose an image:",
        key="ingredients",
        help="An image of the ingredients that you want to use.",
    )
    if uploaded_file_ingredients is None:
        return
    ingredients = _detect_ingredients(uploaded_file_ingredients)
    if ingredients is None:
        return
    chef = _render_recipe_section(food, ingredients)
    # The chat is available once a recipe has been stored in this session.
    if "recipe" in st.session_state:
        _render_chat(chef)


# Main function to create the Streamlit web application.
def main():
    """Render the ChefBot page and always finish with the GitHub footer.

    Any exception raised while rendering the interactive part is shown with
    ``st.error`` rather than aborting the script, so the footer is still drawn.
    """
    try:
        _render_app()
    except Exception as e:  # Top-level UI boundary: surface every failure on the page.
        st.error(e)
    # GitHub repository of author.
    st.markdown(
        """
<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
<a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
</p>
""",
        unsafe_allow_html=True,
    )
############################################################################################################################# | |
# Build the page only when this file runs as the main script, not when it is imported.
if __name__ == "__main__":
    main()