tonyliu404 committed on
Commit
ef2318a
·
verified ·
1 Parent(s): a8be7a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -1
app.py CHANGED
@@ -18,4 +18,138 @@ from langchain.schema.output_parser import StrOutputParser
18
  from langchain_core.messages import HumanMessage, SystemMessage
19
 
20
  df = pd.read_csv('./RAW_recipes.csv')
21
- print(df.head())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  from langchain_core.messages import HumanMessage, SystemMessage
19
 
20
  df = pd.read_csv('./RAW_recipes.csv')
21
+
22
# Variables
max_length = 231637  # total number of recipes, i.e. rows in the dataset
curr_len = 10000  # how many recipes we want to process and embed

# Labels for the entries of each recipe's `nutrition` list, in positional order.
# Built once here instead of on every loop iteration.
nutrition_map = [
    "Calorie",
    "Total Fat",
    "Sugar",
    "Sodium",
    "Protein",
    "Saturated Fat",
    "Total Carbohydrate",
]


def _format_recipe(row):
    """Flatten one recipe row into a single-line text description.

    Args:
        row: The twelve column values of one row of ``df``, in column order
            (name, id, minutes, contributor_id, submitted, tags, nutrition,
            n_steps, steps, description, ingredients, n_ingredients).

    Returns:
        A string without any line breaks that combines the recipe's name,
        timing, description, ingredients, tags, labeled nutrition values and
        numbered steps.

    Raises:
        ValueError, SyntaxError: If ``nutrition`` or ``steps`` is not a valid
            Python literal (propagated from ``ast.literal_eval``).
    """
    # The id columns are unpacked only to keep positions aligned; they are not
    # used, and naming one `id` would shadow the builtin.
    (name, _recipe_id, minutes, _contributor_id, submitted, tags, nutrition,
     n_steps, steps, description, ingredients, n_ingredients) = row

    # `nutrition` and `steps` are stored as string-encoded lists; convert them.
    nutrition = ast.literal_eval(nutrition)
    steps = ast.literal_eval(steps)

    # NOTE(review): every entry, including "Calorie", gets the
    # "% daily value" suffix -- confirm that matches the dataset's units.
    nutrition_labeled = [
        f"{label} : {num} % daily value"
        for label, num in zip(nutrition_map, nutrition)
    ]

    # Prefix each step with its 1-based position.
    numbered_steps = [f"{i}. {step}" for i, step in enumerate(steps, start=1)]

    # Fields are joined with explicit single spaces so they stay separated
    # after line breaks are stripped below.
    text = (
        f"{name} : {minutes} minutes, submitted on {submitted} "
        f"description: {description}, "
        f"ingredients: {ingredients} "
        f"number of ingredients: {n_ingredients} "
        f"tags: {tags}, nutrition: {nutrition_labeled}, total steps: {n_steps} "
        f"steps: {numbered_steps}"
    )
    # Field values themselves may contain line breaks; strip them all.
    return text.replace("\r", "").replace("\n", "")


# Build one text description per recipe for the first `curr_len` rows.
recipe_info = [
    _format_recipe(row)
    for row in df.head(curr_len).itertuples(index=False, name=None)
]
56
+
57
+
58
# Splitter configured for chunks of up to 1500 characters with 150 of overlap.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)

# Wrap every recipe string in a Document and split it on its own, giving one
# list of chunks per recipe.
docs = [
    text_splitter.split_documents([Document(page_content=recipe_text)])
    for recipe_text in recipe_info
]

# Flatten the per-recipe chunk lists into a single list of chunks.
merged_documents = [
    piece
    for recipe_chunks in docs
    for piece in recipe_chunks
]
72
+
73
# Hugging Face sentence-transformers model used for embeddings, on the CPU.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
model_kwargs = {"device": "cpu"}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Weaviate client configured with the default EmbeddedOptions.
embedded_options = EmbeddedOptions()
client = weaviate.Client(embedded_options=embedded_options)

# Embed the recipe chunks and load them into the Weaviate vector store.
# by_text=False is passed through as in the original configuration
# (presumably so searches use the supplied embeddings -- TODO confirm).
vector_search = Weaviate.from_documents(
    documents=merged_documents,
    embedding=embeddings,
    client=client,
    by_text=False,
)
93
+
94
+
95
# Expose the Weaviate vector store as a retriever for the RAG chain.

# Basic RAG settings.
# k: number of documents the retriever is asked to return (10).
# score_threshold: relevance cut-off of 0.77 passed along in search_kwargs.
# NOTE(review): with search_type="mmr" the score_threshold entry may not be
# applied by the retriever -- confirm against the LangChain retriever docs.
k = 10
score_threshold = 0.77

retriever_search_kwargs = {"k": k, "score_threshold": score_threshold}
retriever = vector_search.as_retriever(
    search_type="mmr",
    search_kwargs=retriever_search_kwargs,
)
110
+
111
# Prompt for the RAG chain: {context} receives the retrieved recipe text and
# {question} receives the user's query.
template = """
You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question at the end.
The following pieces of retrieved context are recipes.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
Don't say anything mean or offensive.

Context: {context}

Question: {question}
"""

custom_rag_prompt = ChatPromptTemplate.from_template(template)

llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0.2,
)


def _format_docs(docs):
    """Join the text of the retrieved documents into one context string.

    Args:
        docs: Iterable of retrieved documents, each with a ``page_content``
            attribute.

    Returns:
        The documents' text separated by blank lines.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Regular chain format: chain = prompt | model | output_parser
# The retriever output goes through _format_docs so the prompt receives plain
# recipe text rather than the repr of a list of Document objects.
rag_chain = (
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    | custom_rag_prompt
    | llm
    | StrOutputParser()
)
136
+
137
+
138
def get_response(query):
    """Answer a user question with the RAG chain.

    Args:
        query: The question text entered by the user.

    Returns:
        The chain's answer as a string, or a short prompt asking for a
        question when ``query`` is empty or only whitespace.
    """
    # Skip retrieval and the LLM call entirely when there is nothing to ask.
    if query is None or not str(query).strip():
        return "Please enter a question."
    return rag_chain.invoke(query)
140
+
141
+
142
# Gradio UI: a question box, a submit button and an answer box.
# NOTE(review): the nesting of the Row/Column containers is reconstructed from
# the usual layout because the original indentation was not visible -- confirm.
with gr.Blocks(theme=Base(), title="RAG Recipe AI") as demo:
    gr.Markdown("RAG Recipe AI")
    textbox = gr.Textbox(label="Question:")

    with gr.Row():
        button = gr.Button("Submit", variant="primary")

    with gr.Column():
        output1 = gr.Textbox(lines=1, max_lines=10, label="Answer:")

    # Clicking Submit sends the question text to get_response and shows the
    # returned answer in the output box.
    button.click(fn=get_response, inputs=textbox, outputs=[output1])

demo.launch()
153
+
154
+
155
+