TimInf commited on
Commit
288724d
·
verified ·
1 Parent(s): 30f7e2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -261
app.py CHANGED
@@ -1,229 +1,139 @@
1
- import gradio as gr
2
  from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
3
  import torch
4
  import numpy as np
5
  import random
6
- import json # Beibehalten, da es in flutter_api_generate_recipe verwendet wird
 
 
 
 
7
 
8
  # Lade RecipeBERT Modell (für semantische Zutat-Kombination)
9
  bert_model_name = "alexdseo/RecipeBERT"
10
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
11
  bert_model = AutoModel.from_pretrained(bert_model_name)
12
- bert_model.eval() # Setze das Modell in den Evaluationsmodus
13
 
14
  # Lade T5 Rezeptgenerierungsmodell
15
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
16
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
17
  t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
18
 
19
- # Token Mapping für die T5 Modell-Ausgabe
20
  special_tokens = t5_tokenizer.all_special_tokens
21
  tokens_map = {
22
  "<sep>": "--",
23
  "<section>": "\n"
24
  }
25
 
 
 
 
 
 
26
  def get_embedding(text):
27
- """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens"""
28
  inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
29
  with torch.no_grad():
30
  outputs = bert_model(**inputs)
31
-
32
- # Mean Pooling - Mittelwert aller Token-Embeddings
33
  attention_mask = inputs['attention_mask']
34
  token_embeddings = outputs.last_hidden_state
35
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
36
  sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
37
  sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
38
-
39
  return (sum_embeddings / sum_mask).squeeze(0)
40
 
41
  def average_embedding(embedding_list):
42
- """Berechnet den Durchschnitt einer Liste von Embeddings"""
43
  tensors = torch.stack([emb for _, emb in embedding_list])
44
  return tensors.mean(dim=0)
45
 
46
  def get_cosine_similarity(vec1, vec2):
47
- """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren"""
48
- if torch.is_tensor(vec1):
49
- vec1 = vec1.detach().numpy()
50
- if torch.is_tensor(vec2):
51
- vec2 = vec2.detach().numpy()
52
-
53
- # Stelle sicher, dass die Vektoren die richtige Form haben (flachen sie bei Bedarf ab)
54
  vec1 = vec1.flatten()
55
  vec2 = vec2.flatten()
56
-
57
  dot_product = np.dot(vec1, vec2)
58
  norm_a = np.linalg.norm(vec1)
59
  norm_b = np.linalg.norm(vec2)
60
-
61
- # Division durch Null vermeiden
62
- if norm_a == 0 or norm_b == 0:
63
- return 0
64
-
65
  return dot_product / (norm_a * norm_b)
66
 
67
  def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
68
- """Berechnet einen kombinierten Score unter Berücksichtigung der Ähnlichkeit zum Durchschnitt und zu einzelnen Zutaten"""
69
  results = []
70
-
71
  for name, emb in embedding_list:
72
- # Ähnlichkeit zum Durchschnittsvektor
73
  avg_similarity = get_cosine_similarity(query_vector, emb)
74
-
75
- # Durchschnittliche Ähnlichkeit zu einzelnen Zutaten
76
- individual_similarities = [get_cosine_similarity(good_emb, emb)
77
- for _, good_emb in all_good_embeddings]
78
  avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
79
-
80
- # Kombinierter Score (gewichteter Durchschnitt)
81
  combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
82
-
83
  results.append((name, emb, combined_score))
84
-
85
- # Sortiere nach kombiniertem Score (absteigend)
86
  results.sort(key=lambda x: x[2], reverse=True)
87
  return results
88
 
89
  def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
90
- """
91
- Findet die besten Zutaten basierend auf RecipeBERT Embeddings.
92
- """
93
- # Stelle sicher, dass keine Duplikate in den Listen sind
94
  required_ingredients = list(set(required_ingredients))
95
  available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
96
-
97
- # Sonderfall: Wenn keine benötigten Zutaten vorhanden sind, wähle zufällig eine aus den verfügbaren Zutaten
98
  if not required_ingredients and available_ingredients:
99
  random_ingredient = random.choice(available_ingredients)
100
  required_ingredients = [random_ingredient]
101
  available_ingredients = [i for i in available_ingredients if i != random_ingredient]
102
- # print(f"Keine benötigten Zutaten angegeben. Zufällig ausgewählt: {random_ingredient}")
103
-
104
- # Wenn immer noch keine Zutaten vorhanden oder bereits maximale Kapazität erreicht ist
105
  if not required_ingredients or len(required_ingredients) >= max_ingredients:
106
  return required_ingredients[:max_ingredients]
107
-
108
- # Wenn keine zusätzlichen Zutaten verfügbar sind
109
  if not available_ingredients:
110
  return required_ingredients
111
-
112
- # Berechne Embeddings für alle Zutaten
113
  embed_required = [(e, get_embedding(e)) for e in required_ingredients]
114
  embed_available = [(e, get_embedding(e)) for e in available_ingredients]
115
-
116
- # Anzahl der hinzuzufügenden Zutaten
117
  num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
118
-
119
- # Kopiere benötigte Zutaten in die endgültige Liste
120
  final_ingredients = embed_required.copy()
121
-
122
- # Füge die besten Zutaten hinzu
123
  for _ in range(num_to_add):
124
- # Berechne den Durchschnittsvektor der aktuellen Kombination
125
  avg = average_embedding(final_ingredients)
126
-
127
- # Berechne kombinierte Scores für alle Kandidaten
128
  candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
129
-
130
- # Wenn keine Kandidaten mehr übrig sind, breche ab
131
- if not candidates:
132
- break
133
-
134
- # Wähle die beste Zutat
135
  best_name, best_embedding, _ = candidates[0]
136
-
137
- # Füge die beste Zutat zur endgültigen Liste hinzu
138
  final_ingredients.append((best_name, best_embedding))
139
-
140
- # Entferne die Zutat aus den verfügbaren Zutaten
141
  embed_available = [item for item in embed_available if item[0] != best_name]
142
-
143
- # Extrahiere nur die Zutatennamen
144
  return [name for name, _ in final_ingredients]
145
 
146
  def skip_special_tokens(text, special_tokens):
147
- """Entfernt spezielle Tokens aus dem Text"""
148
- for token in special_tokens:
149
- text = text.replace(token, "")
150
  return text
151
 
152
  def target_postprocessing(texts, special_tokens):
153
- """Post-processed generierten Text"""
154
- if not isinstance(texts, list):
155
- texts = [texts]
156
-
157
  new_texts = []
158
  for text in texts:
159
  text = skip_special_tokens(text, special_tokens)
160
-
161
- for k, v in tokens_map.items():
162
- text = text.replace(k, v)
163
-
164
  new_texts.append(text)
165
-
166
  return new_texts
167
 
168
  def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
169
- """
170
- Validiert, ob das Rezept ungefähr die erwarteten Zutaten enthält.
171
- """
172
  recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
173
  expected_count = len(expected_ingredients)
174
  return abs(recipe_count - expected_count) == tolerance
175
 
176
  def generate_recipe_with_t5(ingredients_list, max_retries=5):
177
- """Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung."""
178
  original_ingredients = ingredients_list.copy()
179
-
180
  for attempt in range(max_retries):
181
  try:
182
- # Für Wiederholungsversuche nach dem ersten Versuch, mische die Zutaten
183
  if attempt > 0:
184
  current_ingredients = original_ingredients.copy()
185
  random.shuffle(current_ingredients)
186
  else:
187
  current_ingredients = ingredients_list
188
-
189
- # Formatiere Zutaten als kommaseparierten String
190
  ingredients_string = ", ".join(current_ingredients)
191
  prefix = "items: "
192
-
193
- # Generationseinstellungen
194
  generation_kwargs = {
195
- "max_length": 512,
196
- "min_length": 64,
197
- "do_sample": True,
198
- "top_k": 60,
199
- "top_p": 0.95
200
  }
201
- # print(f"Versuch {attempt + 1}: {prefix + ingredients_string}")
202
-
203
- # Tokenisiere Eingabe
204
  inputs = t5_tokenizer(
205
- prefix + ingredients_string,
206
- max_length=256,
207
- padding="max_length",
208
- truncation=True,
209
- return_tensors="jax"
210
  )
211
-
212
- # Generiere Text
213
  output_ids = t5_model.generate(
214
- input_ids=inputs.input_ids,
215
- attention_mask=inputs.attention_mask,
216
- **generation_kwargs
217
  )
218
-
219
- # Dekodieren und Nachbearbeiten
220
  generated = output_ids.sequences
221
- generated_text = target_postprocessing(
222
- t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
223
- special_tokens
224
- )[0]
225
-
226
- # Abschnitte parsen
227
  recipe = {}
228
  sections = generated_text.split("\n")
229
  for section in sections:
@@ -236,65 +146,40 @@ def generate_recipe_with_t5(ingredients_list, max_retries=5):
236
  elif section.startswith("directions:"):
237
  directions_text = section.replace("directions:", "").strip()
238
  recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
239
-
240
- # Wenn der Titel fehlt, erstelle einen
241
  if "title" not in recipe:
242
  recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
243
-
244
- # Stelle sicher, dass alle Abschnitte existieren
245
  if "ingredients" not in recipe:
246
  recipe["ingredients"] = current_ingredients
247
  if "directions" not in recipe:
248
  recipe["directions"] = ["Keine Anweisungen generiert"]
249
-
250
- # Validiere das Rezept
251
  if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
252
- # print(f"Erfolg bei Versuch {attempt + 1}: Rezept hat die richtige Anzahl von Zutaten")
253
  return recipe
254
  else:
255
- # print(f"Versuch {attempt + 1} fehlgeschlagen: Erwartet {len(original_ingredients)} Zutaten, erhalten {len(recipe['ingredients'])}")
256
- if attempt == max_retries - 1:
257
- # print("Maximale Wiederholungsversuche erreicht, letztes generiertes Rezept wird zurückgegeben")
258
- return recipe
259
-
260
  except Exception as e:
261
- # print(f"Fehler bei der Rezeptgenerierung Versuch {attempt + 1}: {str(e)}")
262
  if attempt == max_retries - 1:
263
  return {
264
  "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
265
  "ingredients": original_ingredients,
266
  "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
267
  }
268
-
269
- # Fallback (sollte nicht erreicht werden)
270
  return {
271
  "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
272
  "ingredients": original_ingredients,
273
  "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
274
  }
275
 
276
- # Diese Funktion wird von der Gradio-UI und der FastAPI-Route aufgerufen.
277
- # Sie ist für die Kernlogik zuständig.
278
  def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
279
  """
280
  Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
281
- Ausgelagert, um von verschiedenen Endpunkten aufgerufen zu werden.
282
  """
283
  if not required_ingredients and not available_ingredients:
284
  return {"error": "Keine Zutaten angegeben"}
285
-
286
  try:
287
- # Optimale Zutaten finden
288
  optimized_ingredients = find_best_ingredients(
289
- required_ingredients,
290
- available_ingredients,
291
- max_ingredients
292
  )
293
-
294
- # Rezept mit optimierten Zutaten generieren
295
  recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
296
-
297
- # Ergebnis formatieren
298
  result = {
299
  'title': recipe['title'],
300
  'ingredients': recipe['ingredients'],
@@ -302,126 +187,40 @@ def process_recipe_request_logic(required_ingredients, available_ingredients, ma
302
  'used_ingredients': optimized_ingredients
303
  }
304
  return result
305
-
306
  except Exception as e:
307
  return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
308
 
309
- def flutter_api_generate_recipe(ingredients_data: str): # Typ-Hint für Klarheit
310
- """
311
- Diese Funktion wird vom 'hugging_face_chat_gradio'-Paket über die API aufgerufen.
312
- Sie erwartet einen JSON-STRING als Eingabe.
313
- """
314
- try:
315
- # Der 'hugging_face_chat_gradio'-Client sendet das Payload als String.
316
- data = json.loads(ingredients_data)
317
-
318
- required_ingredients = data.get('required_ingredients', [])
319
- available_ingredients = data.get('available_ingredients', [])
320
- max_ingredients = data.get('max_ingredients', 7)
321
- max_retries = data.get('max_retries', 5)
322
-
323
- # Rufe die Kernlogik auf
324
- result_dict = process_recipe_request_logic(
325
- required_ingredients, available_ingredients, max_ingredients, max_retries
326
- )
327
- return json.dumps(result_dict) # Gibt einen JSON-STRING zurück
328
-
329
- except Exception as e:
330
- # Logge den Fehler für Debugging im Space-Log
331
- print(f"Error in flutter_api_generate_recipe: {str(e)}")
332
- return json.dumps({"error": f"Internal API Error: {str(e)}"})
333
-
334
- def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients_val, max_retries_val):
335
- """Gradio UI Funktion für die Web-Oberfläche"""
336
- try:
337
- required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]
338
- available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]
339
-
340
- # Rufe die Kernlogik auf
341
- result = process_recipe_request_logic(
342
- required_ingredients, available_ingredients, max_ingredients_val, max_retries_val
343
- )
344
-
345
- if 'error' in result:
346
- return result['error'], "", "", ""
347
-
348
- ingredients_list = '\n'.join([f"• {ing}" for ing in result['ingredients']])
349
- directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(result['directions'])])
350
- used_ingredients = ', '.join(result['used_ingredients'])
351
-
352
- return (
353
- result['title'],
354
- ingredients_list,
355
- directions_list,
356
- used_ingredients
357
- )
358
-
359
- except Exception as e:
360
- # Fehlermeldung für die Gradio UI
361
- return f"Fehler: {str(e)}", "", "", ""
362
-
363
- # Erstelle die Gradio Oberfläche
364
- with gr.Blocks(title="AI Rezept Generator") as demo:
365
- gr.Markdown("# 🍳 AI Rezept Generator")
366
- gr.Markdown("Generiere Rezepte mit KI und intelligenter Zutat-Kombination!")
367
-
368
- with gr.Tab("Web-Oberfläche"):
369
- with gr.Row():
370
- with gr.Column():
371
- required_ing = gr.Textbox(
372
- label="Benötigte Zutaten (kommasepariert)",
373
- placeholder="Hähnchen, Reis, Zwiebel",
374
- lines=2
375
- )
376
- available_ing = gr.Textbox(
377
- label="Verfügbare Zutaten (kommasepariert, optional)",
378
- placeholder="Knoblauch, Tomate, Pfeffer, Kräuter",
379
- lines=2
380
- )
381
- max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximale Zutaten")
382
- max_retries = gr.Slider(1, 10, value=5, step=1, label="Max. Wiederholungsversuche")
383
-
384
- generate_btn = gr.Button("Rezept generieren", variant="primary")
385
 
386
- with gr.Column():
387
- title_output = gr.Textbox(label="Rezepttitel", interactive=False)
388
- ingredients_output = gr.Textbox(label="Zutaten", lines=8, interactive=False)
389
- directions_output = gr.Textbox(label="Anweisungen", lines=10, interactive=False)
390
- used_ingredients_output = gr.Textbox(label="Verwendete Zutaten", interactive=False)
 
 
391
 
392
- generate_btn.click(
393
- fn=gradio_ui_generate_recipe,
394
- inputs=[required_ing, available_ing, max_ing, max_retries],
395
- outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
396
- )
397
-
398
- with gr.Tab("API-Test"):
399
- gr.Markdown("### Teste die Flutter API (via 'hugging_face_chat_gradio' Client)")
400
- gr.Markdown("Dieser Tab zeigt, wie die Eingabe für die 'generate_recipe_for_flutter'-API aussehen sollte.")
401
-
402
- api_input = gr.Textbox(
403
- label="JSON-Eingabe (für API-Aufruf)",
404
- placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
405
- lines=4
406
- )
407
- api_output = gr.Textbox(label="JSON-Ausgabe", lines=15, interactive=False)
408
- api_test_btn = gr.Button("API testen", variant="secondary")
409
-
410
- api_test_btn.click(
411
- fn=flutter_api_generate_recipe,
412
- inputs=[api_input],
413
- outputs=[api_output],
414
- api_name="generate_recipe_for_flutter" # Dies ist der api_name, den das Flutter-Paket verwendet
415
- )
416
-
417
- gr.Examples(
418
- examples=[
419
- ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
420
- ['{"ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}']
421
- ],
422
- inputs=[api_input]
423
- )
424
-
425
- # Gradio-App starten
426
- if __name__ == "__main__":
427
- demo.launch()
 
 
1
  from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
2
  import torch
3
  import numpy as np
4
  import random
5
+ import json
6
+ from fastapi import FastAPI
7
+ from fastapi.responses import JSONResponse
8
+ from pydantic import BaseModel
9
+ # Keine Gradio-Imports hier!
10
 
11
  # Lade RecipeBERT Modell (für semantische Zutat-Kombination)
12
  bert_model_name = "alexdseo/RecipeBERT"
13
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
14
  bert_model = AutoModel.from_pretrained(bert_model_name)
15
+ bert_model.eval()
16
 
17
  # Lade T5 Rezeptgenerierungsmodell
18
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
19
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
20
  t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
21
 
22
+ # Token Mapping (bleibt gleich)
23
  special_tokens = t5_tokenizer.all_special_tokens
24
  tokens_map = {
25
  "<sep>": "--",
26
  "<section>": "\n"
27
  }
28
 
29
+ # Deine Helper-Funktionen (get_embedding, average_embedding, get_cosine_similarity, etc.)
30
+ # ... diese bleiben ALLE GLEICH wie in deinem aktuellen app.py Code ...
31
+ # Kopiere alle Funktionen von 'get_embedding' bis 'generate_recipe_with_t5' hierher.
32
+ # (Ich kürze sie hier aus Platzgründen, aber sie müssen vollständig eingefügt werden)
33
+
34
  def get_embedding(text):
 
35
  inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
36
  with torch.no_grad():
37
  outputs = bert_model(**inputs)
 
 
38
  attention_mask = inputs['attention_mask']
39
  token_embeddings = outputs.last_hidden_state
40
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
41
  sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
42
  sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
43
  return (sum_embeddings / sum_mask).squeeze(0)
44
 
45
  def average_embedding(embedding_list):
 
46
  tensors = torch.stack([emb for _, emb in embedding_list])
47
  return tensors.mean(dim=0)
48
 
49
  def get_cosine_similarity(vec1, vec2):
50
+ if torch.is_tensor(vec1): vec1 = vec1.detach().numpy()
51
+ if torch.is_tensor(vec2): vec2 = vec2.detach().numpy()
 
 
 
 
 
52
  vec1 = vec1.flatten()
53
  vec2 = vec2.flatten()
 
54
  dot_product = np.dot(vec1, vec2)
55
  norm_a = np.linalg.norm(vec1)
56
  norm_b = np.linalg.norm(vec2)
57
+ if norm_a == 0 or norm_b == 0: return 0
 
 
 
 
58
  return dot_product / (norm_a * norm_b)
59
 
60
  def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
 
61
  results = []
 
62
  for name, emb in embedding_list:
 
63
  avg_similarity = get_cosine_similarity(query_vector, emb)
64
+ individual_similarities = [get_cosine_similarity(good_emb, emb) for _, good_emb in all_good_embeddings]
 
 
 
65
  avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
 
 
66
  combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
 
67
  results.append((name, emb, combined_score))
 
 
68
  results.sort(key=lambda x: x[2], reverse=True)
69
  return results
70
 
71
  def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
 
 
 
 
72
  required_ingredients = list(set(required_ingredients))
73
  available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
 
 
74
  if not required_ingredients and available_ingredients:
75
  random_ingredient = random.choice(available_ingredients)
76
  required_ingredients = [random_ingredient]
77
  available_ingredients = [i for i in available_ingredients if i != random_ingredient]
 
 
 
78
  if not required_ingredients or len(required_ingredients) >= max_ingredients:
79
  return required_ingredients[:max_ingredients]
 
 
80
  if not available_ingredients:
81
  return required_ingredients
 
 
82
  embed_required = [(e, get_embedding(e)) for e in required_ingredients]
83
  embed_available = [(e, get_embedding(e)) for e in available_ingredients]
 
 
84
  num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
 
 
85
  final_ingredients = embed_required.copy()
 
 
86
  for _ in range(num_to_add):
 
87
  avg = average_embedding(final_ingredients)
 
 
88
  candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
89
+ if not candidates: break
 
 
 
 
 
90
  best_name, best_embedding, _ = candidates[0]
 
 
91
  final_ingredients.append((best_name, best_embedding))
 
 
92
  embed_available = [item for item in embed_available if item[0] != best_name]
 
 
93
  return [name for name, _ in final_ingredients]
94
 
95
  def skip_special_tokens(text, special_tokens):
96
+ for token in special_tokens: text = text.replace(token, "")
 
 
97
  return text
98
 
99
  def target_postprocessing(texts, special_tokens):
100
+ if not isinstance(texts, list): texts = [texts]
 
 
 
101
  new_texts = []
102
  for text in texts:
103
  text = skip_special_tokens(text, special_tokens)
104
+ for k, v in tokens_map.items(): text = text.replace(k, v)
 
 
 
105
  new_texts.append(text)
 
106
  return new_texts
107
 
108
  def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
 
 
 
109
  recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
110
  expected_count = len(expected_ingredients)
111
  return abs(recipe_count - expected_count) == tolerance
112
 
113
  def generate_recipe_with_t5(ingredients_list, max_retries=5):
 
114
  original_ingredients = ingredients_list.copy()
 
115
  for attempt in range(max_retries):
116
  try:
 
117
  if attempt > 0:
118
  current_ingredients = original_ingredients.copy()
119
  random.shuffle(current_ingredients)
120
  else:
121
  current_ingredients = ingredients_list
 
 
122
  ingredients_string = ", ".join(current_ingredients)
123
  prefix = "items: "
 
 
124
  generation_kwargs = {
125
+ "max_length": 512, "min_length": 64, "do_sample": True,
126
+ "top_k": 60, "top_p": 0.95
 
 
 
127
  }
 
 
 
128
  inputs = t5_tokenizer(
129
+ prefix + ingredients_string, max_length=256, padding="max_length",
130
+ truncation=True, return_tensors="jax"
 
 
 
131
  )
 
 
132
  output_ids = t5_model.generate(
133
+ input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, **generation_kwargs
 
 
134
  )
 
 
135
  generated = output_ids.sequences
136
+ generated_text = target_postprocessing(t5_tokenizer.batch_decode(generated, skip_special_tokens=False), special_tokens)[0]
 
 
 
 
 
137
  recipe = {}
138
  sections = generated_text.split("\n")
139
  for section in sections:
 
146
  elif section.startswith("directions:"):
147
  directions_text = section.replace("directions:", "").strip()
148
  recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
 
 
149
  if "title" not in recipe:
150
  recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
 
 
151
  if "ingredients" not in recipe:
152
  recipe["ingredients"] = current_ingredients
153
  if "directions" not in recipe:
154
  recipe["directions"] = ["Keine Anweisungen generiert"]
 
 
155
  if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
 
156
  return recipe
157
  else:
158
+ if attempt == max_retries - 1: return recipe
 
 
 
 
159
  except Exception as e:
 
160
  if attempt == max_retries - 1:
161
  return {
162
  "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
163
  "ingredients": original_ingredients,
164
  "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
165
  }
 
 
166
  return {
167
  "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
168
  "ingredients": original_ingredients,
169
  "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
170
  }
171
 
 
 
172
  def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
173
  """
174
  Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
 
175
  """
176
  if not required_ingredients and not available_ingredients:
177
  return {"error": "Keine Zutaten angegeben"}
 
178
  try:
 
179
  optimized_ingredients = find_best_ingredients(
180
+ required_ingredients, available_ingredients, max_ingredients
 
 
181
  )
 
 
182
  recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
 
 
183
  result = {
184
  'title': recipe['title'],
185
  'ingredients': recipe['ingredients'],
 
187
  'used_ingredients': optimized_ingredients
188
  }
189
  return result
 
190
  except Exception as e:
191
  return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
192
 
193
+ # --- FastAPI-Implementierung ---
194
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
+ class RecipeRequest(BaseModel):
197
+ required_ingredients: list[str] = []
198
+ available_ingredients: list[str] = []
199
+ max_ingredients: int = 7
200
+ max_retries: int = 5
201
+ # Optional: Für Abwärtskompatibilität, falls 'ingredients' als Top-Level-Feld gesendet wird
202
+ # ingredients: list[str] = [] # Dies würde auch akzeptiert und müsste dann in der Logik verarbeitet werden
203
 
204
+ @app.post("/generate_recipe") # Einfacher Endpunkt, den Flutter aufruft
205
+ async def generate_recipe_api(request_data: RecipeRequest):
206
+ """
207
+ Standard-REST-API-Endpunkt für die Flutter-App.
208
+ Nimmt direkt JSON-Daten an und gibt direkt JSON zurück.
209
+ """
210
+ # Verarbeite optionale Abwärtskompatibilität hier, falls nötig
211
+ if not request_data.required_ingredients and 'ingredients' in request_data.model_dump():
212
+ request_data.required_ingredients = request_data.model_dump()['ingredients']
213
+
214
+ result_dict = process_recipe_request_logic(
215
+ request_data.required_ingredients,
216
+ request_data.available_ingredients,
217
+ request_data.max_ingredients,
218
+ request_data.max_retries
219
+ )
220
+ return JSONResponse(content=result_dict)
221
+
222
+ # In diesem Setup gibt es keine Gradio UI, nur die FastAPI-API.
223
+ # Dadurch sollte der Space zuverlässiger starten.
224
+
225
+ # Der if __name__ == "__main__": Block wird von Hugging Face Spaces ignoriert,
226
+ # da sie den Uvicorn-Server direkt starten, der die 'app'-Variable sucht.