import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import pprint
import os
import ast
import gradio as gr
from gradio.themes.base import Base
import weaviate
from weaviate.embedded import EmbeddedOptions
from langchain_community.vectorstores import Weaviate
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain_community.chat_models import ChatOpenAI
from kaggle_secrets import UserSecretsClient
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_core.messages import HumanMessage, SystemMessage
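# Make the OpenAI key available to ChatOpenAI below. This assumes the key is
# stored as a Kaggle secret; the secret label "OPENAI_API_KEY" is an assumption,
# use whatever label the key is actually saved under.
user_secrets = UserSecretsClient()
os.environ["OPENAI_API_KEY"] = user_secrets.get_secret("OPENAI_API_KEY")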
df = pd.read_csv('./RAW_recipes.csv')

# Variables
max_length = 231637  # total number of recipes (rows) in the dataset
curr_len = 10000     # how many recipes we want to process and embed

# Format each recipe row into a single descriptive string
curr_i = 0
recipe_info = []
for index, row in df.iterrows():
    if curr_i >= curr_len:
        break
    curr_i += 1
    name, id, minutes, contributor_id, submitted, tags, nutrition, n_steps, steps, description, ingredients, n_ingredients = row
    # convert the stringified lists back into Python lists
    nutrition = ast.literal_eval(nutrition)
    steps = ast.literal_eval(steps)
    # format nutrition
    nutrition_map = ["Calorie", "Total Fat", "Sugar", "Sodium", "Protein", "Saturated Fat", "Total Carbohydrate"]
    nutrition_labeled = []
    for label, num in zip(nutrition_map, nutrition):
        if label == "Calorie":
            nutrition_labeled.append(f"{label} : {num} per serving")
        else:
            nutrition_labeled.append(f"{label} : {num} % daily value")
    # format steps as a numbered list
    for i in range(len(steps)):
        steps[i] = f"{i+1}. " + steps[i]
    recipe_info.append(f'''
{name} : {minutes} minutes, submitted on {submitted}
description: {description},
ingredients: {ingredients}
number of ingredients: {n_ingredients}
tags: {tags}, nutrition: {nutrition_labeled}, total steps: {n_steps}
steps: {steps}
'''.replace("\r", "").replace("\n", ""))
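# Optional sanity check (illustrative, not part of the app's flow): print one
# formatted recipe string to confirm the fields were concatenated as expected.
# Uncomment to run:
# print(recipe_info[0])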
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)

# split recipe_info into chunks
docs = []
for doc in recipe_info:
    # Wrap each string in a Document object
    document = Document(page_content=doc)  # create a Document object with the content
    chunks = text_splitter.split_documents([document])  # pass a list of Document objects
    docs.append(chunks)

# merge all chunk lists into one flat list
merged_documents = []
for doc in docs:
    merged_documents.extend(doc)
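# Optional sanity check (illustrative): the chunk count is at least
# len(recipe_info); recipes longer than 1500 characters are split into more
# than one chunk. Uncomment to run:
# print(f"{len(recipe_info)} recipes -> {len(merged_documents)} chunks")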
# Hugging Face model for embeddings.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
model_kwargs = {'device': 'cpu'}
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
)
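# Optional sanity check (illustrative; the sample query is arbitrary):
# all-MiniLM-L6-v2 produces 384-dimensional vectors, so the embedding below
# should have length 384. Uncomment to run:
# sample_vector = embeddings.embed_query("easy apple dessert")
# print(len(sample_vector))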
# initialize an embedded Weaviate client (runs Weaviate locally, no external server needed)
client = weaviate.Client(
    embedded_options=EmbeddedOptions()
)

# embed the recipe chunks and store them in Weaviate
vector_search = Weaviate.from_documents(
    client=client,
    documents=merged_documents,
    embedding=embeddings,
    by_text=False
)
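# Optional sanity check (illustrative; the query string is arbitrary):
# Uncomment to run:
# hits = vector_search.similarity_search("easy apple dessert", k=3)
# for hit in hits:
#     print(hit.page_content[:120])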
# Instantiate Weaviate vector search as a retriever.
# Basic RAG.
# k limits retrieval to the 10 most relevant documents.
# score_threshold is intended to keep only documents with a relevance score above 0.77
# (note: it only takes effect with the "similarity_score_threshold" search type, not with "mmr").
k = 10
score_threshold = 0.77
retriever = vector_search.as_retriever(
    search_type="mmr",
    search_kwargs={
        "k": k,
        "score_threshold": score_threshold
    }
)
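# Optional sanity check (illustrative; the query string is arbitrary):
# Uncomment to run:
# docs_found = retriever.get_relevant_documents("easy apple dessert")
# print(len(docs_found))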
template = """
You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question at the end.
The pieces of retrieved context are recipes.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
Don't say anything mean or offensive.
Context: {context}
Question: {question}
"""
custom_rag_prompt = ChatPromptTemplate.from_template(template)

llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0.2)
# Regular chain format: chain = prompt | model | output_parser
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | custom_rag_prompt
    | llm
    | StrOutputParser()
)

def get_response(query):
    return rag_chain.invoke(query)
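# Optional sanity check (illustrative; the question is arbitrary and running it
# makes an OpenAI API call). Uncomment to run:
# print(get_response("What is an easy dessert I can make with apples?"))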
with gr.Blocks(theme=Base(), title="RAG Recipe AI") as demo:
    gr.Markdown("""
# RAG Recipe AI
This model will answer all your recipe-related questions.
Enter a question about a recipe, and the system will return an answer based on 10,000 food.com recipes stored in the vector database.

Features Considered:
- Cook Time
- Nutrition Information
- Steps
- Ingredients
- Dish Description

Sample Queries:
- What is an easy dessert I can make with apples?
- What is the nutritional information of a Caesar salad?
- How many calories are in an average American burger?
""")
    textbox = gr.Textbox(label="Question:")
    with gr.Row():
        button = gr.Button("Submit", variant="primary")
    with gr.Column():
        output1 = gr.Textbox(lines=1, max_lines=10, label="Answer:")
    # Call the get_response function upon clicking the Submit button.
    button.click(get_response, textbox, outputs=[output1])

demo.launch()