TimInf committed on
Commit
5a007ca
·
verified ·
1 Parent(s): 9837e7a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +462 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
3
+ import torch
4
+ import numpy as np
5
+ import random
6
+ import json
7
+
8
# RecipeBERT encoder: produces the ingredient embeddings used for similarity scoring.
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()  # put the encoder into evaluation mode

# Flax T5 sequence-to-sequence model: generates the recipe text.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Used when cleaning up generated text: the tokenizer's special tokens are
# stripped, and the structural markers below are rewritten to plain separators.
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {"<sep>": "--", "<section>": "\n"}
25
+
26
+
27
def get_embedding(text):
    """Return the mean-pooled RecipeBERT embedding of ``text``.

    The text is tokenized (with truncation and padding), passed through
    ``bert_model`` without gradient tracking, and the token vectors of the
    last hidden state are averaged using the attention mask as weights.
    A leading dimension of size one is squeezed away before returning.
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state

    # Expand the attention mask across the hidden dimension so masked-out
    # positions contribute nothing to the sum.
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = torch.sum(hidden_states * mask, 1)
    # Clamp the denominator so an all-zero mask cannot divide by zero.
    token_count = torch.clamp(mask.sum(1), min=1e-9)

    pooled = masked_total / token_count
    return pooled.squeeze(0)
41
+
42
+
43
def average_embedding(embedding_list):
    """Return the element-wise mean of the tensors in ``embedding_list``.

    ``embedding_list`` is a sequence of ``(name, tensor)`` pairs; the names
    are ignored and the tensors are stacked along a new first dimension
    before averaging over it.
    """
    vectors = [vector for _name, vector in embedding_list]
    stacked = torch.stack(vectors)
    return stacked.mean(dim=0)
47
+
48
+
49
def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors as a Python ``float``.

    Args:
        vec1: A torch tensor, NumPy array, or plain sequence of numbers.
        vec2: Same kinds of input as ``vec1``.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)`` computed on the flattened
        inputs, or ``0.0`` when either vector has zero norm.
    """
    def _as_flat_array(vec):
        if torch.is_tensor(vec):
            # .cpu() leaves CPU tensors untouched and lets tensors living on
            # another device be converted to NumPy.
            vec = vec.detach().cpu().numpy()
        # np.asarray also accepts plain lists/tuples.
        return np.asarray(vec).flatten()

    flat_a = _as_flat_array(vec1)
    flat_b = _as_flat_array(vec2)

    norm_a = np.linalg.norm(flat_a)
    norm_b = np.linalg.norm(flat_b)

    # Avoid division by zero; return a float so the result type is uniform.
    if norm_a == 0 or norm_b == 0:
        return 0.0

    return float(np.dot(flat_a, flat_b) / (norm_a * norm_b))
69
+
70
+
71
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Score candidate embeddings against a query vector and a reference set.

    Args:
        query_vector: Vector each candidate is compared to (the caller passes
            the average embedding of the ingredients chosen so far).
        embedding_list: Candidate ``(name, embedding)`` pairs to score.
        all_good_embeddings: Reference ``(name, embedding)`` pairs; each
            candidate is also compared to every one of them individually.
        avg_weight: Weight of the query-vector similarity; the mean
            individual similarity gets ``1 - avg_weight``.

    Returns:
        A list of ``(name, embedding, combined_score)`` tuples sorted by
        score, highest first.
    """
    results = []

    for name, emb in embedding_list:
        # Similarity to the query (average) vector.
        avg_similarity = get_cosine_similarity(query_vector, emb)

        # Mean similarity to each individual reference embedding.
        individual_similarities = [
            get_cosine_similarity(good_emb, emb) for _, good_emb in all_good_embeddings
        ]
        if individual_similarities:
            avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
        else:
            # No reference embeddings: fall back to the query similarity
            # instead of dividing by zero, so the score stays on the same scale.
            avg_individual_similarity = avg_similarity

        combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
        results.append((name, emb, combined_score))

    # Best candidate first.
    results.sort(key=lambda item: item[2], reverse=True)
    return results
92
+
93
+
94
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Pick a combination of ingredients using RecipeBERT embeddings.

    Starting from the required ingredients, available ingredients are added
    greedily: at each step the candidate with the highest combined score
    (see ``get_combined_scores``) against the current selection is chosen.

    Args:
        required_ingredients: Ingredient names that must be part of the result.
        available_ingredients: Optional extra ingredient names to choose from.
        max_ingredients: Upper bound on the number of returned ingredients.
            Floats are truncated to ``int``; negative values are treated as 0.
        avg_weight: Forwarded to ``get_combined_scores``.

    Returns:
        A list of ingredient names: the required ones in their given order
        (duplicates removed), followed by the selected additions in the order
        they were chosen. If no required ingredient is given, one available
        ingredient is picked at random as the starting point.
    """
    # UI widgets may hand over the limit as a float; slicing and range() need an int.
    max_ingredients = max(int(max_ingredients), 0)

    # Strip blanks and de-duplicate. dict.fromkeys keeps first-seen order, so
    # the result (and any truncation below) is deterministic across runs.
    required_ingredients = list(dict.fromkeys(
        ing.strip() for ing in required_ingredients if ing.strip()
    ))
    required_lookup = set(required_ingredients)
    available_ingredients = list(dict.fromkeys(
        ing.strip() for ing in available_ingredients
        if ing.strip() and ing.strip() not in required_lookup
    ))

    # Special case: nothing required, so seed the selection with a random
    # available ingredient.
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]

    # Nothing to work with, or the required ingredients already fill the quota.
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    # No optional ingredients to add.
    if not available_ingredients:
        return required_ingredients

    # Embed every ingredient once up front.
    final_ingredients = [(name, get_embedding(name)) for name in required_ingredients]
    embed_available = [(name, get_embedding(name)) for name in available_ingredients]

    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))

    for _ in range(num_to_add):
        # Score remaining candidates against the current selection.
        avg = average_embedding(final_ingredients)
        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
        if not candidates:
            break

        best_name, best_embedding, _ = candidates[0]
        final_ingredients.append((best_name, best_embedding))

        # The chosen ingredient is no longer a candidate.
        embed_available = [item for item in embed_available if item[0] != best_name]

    return [name for name, _ in final_ingredients]
153
+
154
+
155
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token deleted.

    Tokens are removed one after another in the order given.
    """
    cleaned = text
    for marker in special_tokens:
        if marker in cleaned:
            cleaned = cleaned.replace(marker, "")
    return cleaned
160
+
161
+
162
def target_postprocessing(texts, special_tokens):
    """Clean up generated text and return the results as a list.

    A single non-list value is wrapped in a list. For each text the special
    tokens are removed first, then every key of ``tokens_map`` is replaced by
    its mapped value.
    """
    batch = texts if isinstance(texts, list) else [texts]

    cleaned_batch = []
    for raw in batch:
        cleaned = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            cleaned = cleaned.replace(marker, replacement)
        cleaned_batch.append(cleaned)

    return cleaned_batch
177
+
178
+
179
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that the recipe has roughly as many ingredients as expected.

    Only non-blank recipe entries are counted. Returns ``True`` when that
    count differs from ``len(expected_ingredients)`` by at most ``tolerance``.
    """
    non_blank = sum(1 for entry in recipe_ingredients if entry and entry.strip())
    difference = non_blank - len(expected_ingredients)
    return -tolerance <= difference <= tolerance
184
+
185
+
186
def _split_recipe_items(section_body):
    """Split a ``--``-separated section body into capitalized, non-empty items."""
    return [item.strip().capitalize() for item in section_body.split("--") if item.strip()]


def _parse_recipe_sections(generated_text):
    """Extract the title / ingredients / directions sections from generated text."""
    recipe = {}
    for section in generated_text.split("\n"):
        section = section.strip()
        # Slice off only the leading label; text after it is kept verbatim
        # even if it happens to contain the label again.
        if section.startswith("title:"):
            recipe["title"] = section[len("title:"):].strip().capitalize()
        elif section.startswith("ingredients:"):
            recipe["ingredients"] = _split_recipe_items(section[len("ingredients:"):])
        elif section.startswith("directions:"):
            recipe["directions"] = _split_recipe_items(section[len("directions:"):])
    return recipe


def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when generation did not succeed."""
    return {
        "title": f"Recipe with {ingredients[0] if ingredients else 'ingredients'}",
        "ingredients": ingredients,
        "directions": ["Error generating recipe instructions"]
    }


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe with the T5 model, retrying until it validates.

    Args:
        ingredients_list: Ingredient names to build the recipe from.
        max_retries: Maximum number of generation attempts. The first attempt
            uses the ingredients in the given order; later attempts shuffle them.

    Returns:
        A dict with the keys ``"title"``, ``"ingredients"`` and
        ``"directions"``. The first recipe whose ingredient count is within
        one of the requested count is returned. If no attempt validates, the
        recipe from the final attempt is returned. If the final attempt
        raised (or no attempt was made), a placeholder recipe is returned.
    """
    import logging  # local import keeps this block self-contained

    logger = logging.getLogger(__name__)
    original_ingredients = list(ingredients_list)

    # Sampling settings are identical for every attempt.
    prefix = "items: "
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95
    }

    last_recipe = None
    for attempt in range(max_retries):
        last_recipe = None
        try:
            # Retries present the same ingredients in a different order.
            current_ingredients = list(original_ingredients)
            if attempt > 0:
                random.shuffle(current_ingredients)

            inputs = t5_tokenizer(
                prefix + ", ".join(current_ingredients),
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )

            # Decode with special tokens kept so the structural markers can
            # be translated by target_postprocessing.
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False),
                special_tokens
            )[0]

            recipe = _parse_recipe_sections(generated_text)

            # Fill in any section that was not generated.
            recipe.setdefault("title", f"Recipe with {', '.join(current_ingredients[:3])}")
            recipe.setdefault("ingredients", current_ingredients)
            recipe.setdefault("directions", ["No directions generated"])

            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients, tolerance=1):
                return recipe
            last_recipe = recipe

        except Exception:
            # Best effort: record the failure and move on to the next attempt.
            logger.warning(
                "Recipe generation attempt %d/%d failed", attempt + 1, max_retries, exc_info=True
            )

    if last_recipe is not None:
        return last_recipe
    return _fallback_recipe(original_ingredients)
282
+
283
+
284
def generate_recipe_interface(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Gradio callback: build a recipe and return it as four Markdown strings.

    Args:
        required_ingredients_text: Comma-separated ingredients that must be used.
        available_ingredients_text: Comma-separated optional ingredients.
        max_ingredients: Maximum number of ingredients to use.

    Returns:
        A 4-tuple ``(title, ingredients, directions, used_ingredients)`` of
        Markdown strings. On invalid input or any error, the first element
        carries the error message and the other three are empty strings.
    """
    def _split_csv(text):
        # Empty or missing input yields an empty list; blank entries are dropped.
        if not text:
            return []
        return [part.strip() for part in text.split(',') if part.strip()]

    try:
        required_ingredients = _split_csv(required_ingredients_text)
        available_ingredients = _split_csv(available_ingredients_text)

        if not required_ingredients and not available_ingredients:
            return "❌ **Error:** Please provide at least some ingredients!", "", "", ""

        # Find the best ingredient combination. The limit is coerced to int
        # because it comes straight from a UI widget.
        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            int(max_ingredients)
        )

        recipe = generate_recipe_with_t5(optimized_ingredients)

        # Format the Markdown panels.
        title = f"🍽️ **{recipe['title']}**"
        ingredients_formatted = "## 📋 Ingredients:\n" + "\n".join(
            f"• {ing}" for ing in recipe['ingredients'])
        directions_formatted = "## 👨‍🍳 Instructions:\n" + "\n".join(
            f"{i + 1}. {step}" for i, step in enumerate(recipe['directions']))
        used_ingredients = "## ✅ Used Ingredients:\n" + ", ".join(optimized_ingredients)

        return title, ingredients_formatted, directions_formatted, used_ingredients

    except Exception as e:
        # UI boundary: show the failure in the title panel instead of raising.
        return f"❌ **Error:** {str(e)}", "", "", ""
325
+
326
+
327
def generate_recipe_api(required_ingredients_text, available_ingredients_text, max_ingredients):
    """Build a recipe and return it serialized as a JSON string.

    On success the JSON object has the keys ``title``, ``ingredients``,
    ``directions`` and ``used_ingredients``. When no ingredients are given,
    or when anything raises, it has a single ``error`` key instead.
    """
    try:
        # Split the comma-separated inputs, dropping blank entries.
        required_ingredients = (
            [piece for piece in (raw.strip() for raw in required_ingredients_text.split(',')) if piece]
            if required_ingredients_text else []
        )
        available_ingredients = (
            [piece for piece in (raw.strip() for raw in available_ingredients_text.split(',')) if piece]
            if available_ingredients_text else []
        )

        if not (required_ingredients or available_ingredients):
            return json.dumps({"error": "No ingredients provided"}, indent=2)

        # Choose the ingredient combination, then generate the recipe from it.
        optimized_ingredients = find_best_ingredients(
            required_ingredients, available_ingredients, max_ingredients
        )
        recipe = generate_recipe_with_t5(optimized_ingredients)

        payload = {
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients
        }
        return json.dumps(payload, indent=2, ensure_ascii=False)

    except Exception as e:
        return json.dumps({"error": f"Error in recipe generation: {str(e)}"}, indent=2)
366
+
367
+
368
# Build the Gradio UI: one tab renders the recipe as Markdown, the other
# shows the same result as raw JSON.
with gr.Blocks(title="🍳 AI Recipe Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🍳 AI Recipe Generator

    Generate delicious recipes using AI! This tool uses **RecipeBERT** to find the best ingredient combinations and **T5** to generate complete recipes.

    ## How to use:
    1. **Required Ingredients:** Enter ingredients you must use (comma-separated)
    2. **Available Ingredients:** Enter additional ingredients you have available (comma-separated)
    3. **Max Ingredients:** Set the maximum number of ingredients for your recipe
    4. Click **Generate Recipe** to create your personalized recipe!
    """)

    with gr.Tab("🍽️ Recipe Generator"):
        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                required_ingredients = gr.Textbox(
                    label="🎯 Required Ingredients",
                    placeholder="chicken, rice, onions",
                    info="Ingredients that must be included in the recipe (comma-separated)"
                )
                available_ingredients = gr.Textbox(
                    label="🥕 Available Ingredients",
                    placeholder="garlic, tomatoes, basil, cheese",
                    info="Additional ingredients you have available (comma-separated)"
                )
                max_ingredients = gr.Slider(
                    minimum=3, maximum=12, value=7, step=1,
                    label="📊 Maximum Ingredients",
                    info="Maximum number of ingredients to use in the recipe"
                )
                generate_btn = gr.Button("🚀 Generate Recipe", variant="primary", size="lg")

            # Right column: recipe title and the ingredients actually used.
            with gr.Column():
                recipe_title = gr.Markdown()
                used_ingredients = gr.Markdown()

        with gr.Row():
            with gr.Column():
                recipe_ingredients = gr.Markdown()
            with gr.Column():
                recipe_directions = gr.Markdown()

    with gr.Tab("🔌 API Format"):
        gr.Markdown("""
        ## API Response Format
        This tab shows the response in JSON format, compatible with your Flutter app.
        """)

        with gr.Row():
            with gr.Column():
                api_required = gr.Textbox(
                    label="Required Ingredients",
                    placeholder="chicken, rice, onions"
                )
                api_available = gr.Textbox(
                    label="Available Ingredients",
                    placeholder="garlic, tomatoes, basil"
                )
                api_max = gr.Slider(
                    minimum=3, maximum=12, value=7, step=1,
                    label="Max Ingredients"
                )
                api_generate_btn = gr.Button("Generate JSON", variant="secondary")

            with gr.Column():
                api_output = gr.Code(language="json", label="API Response")

    # Event handlers: the output order matches the tuple returned by
    # generate_recipe_interface (title, ingredients, directions, used).
    generate_btn.click(
        fn=generate_recipe_interface,
        inputs=[required_ingredients, available_ingredients, max_ingredients],
        outputs=[recipe_title, recipe_ingredients, recipe_directions, used_ingredients]
    )

    api_generate_btn.click(
        fn=generate_recipe_api,
        inputs=[api_required, api_available, api_max],
        outputs=[api_output]
    )

    # Example inputs; the last one has no required ingredients.
    gr.Examples(
        examples=[
            ["chicken, rice", "onions, garlic, tomatoes, basil", 6],
            ["eggs, flour", "milk, sugar, vanilla, butter", 7],
            ["salmon", "lemon, dill, potatoes, asparagus", 5],
            ["", "beef, potatoes, carrots, onions, garlic", 6]
        ],
        inputs=[required_ingredients, available_ingredients, max_ingredients]
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ transformers>=4.30.0
3
+ torch>=2.0.0
4
+ numpy>=1.21.0
5
+ flax>=0.7.0
6
+ jax>=0.4.0
7
+ jaxlib>=0.4.0