import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
from resources import banner, error_html_response
import logging

logging.basicConfig(format='%(asctime)s: [%(levelname)s]: %(message)s', level=logging.INFO)

model_checkpoint = 'gastronomia-para-to2/gastronomia_para_to2'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

# Structure tokens the model uses to delimit the sections of a generated recipe.
special_tokens = [
    '<INPUT_START>', '<NEXT_INPUT>', '<INPUT_END>',
    '<INGR_START>', '<NEXT_INGR>', '<INGR_END>',
    '<INSTR_START>', '<NEXT_INSTR>', '<INSTR_END>',
    '<TITLE_START>', '<TITLE_END>']


def frame_html_response(html_response):
    # Embed the recipe HTML in an iframe so it renders inside the Gradio HTML output.
    return f"""<iframe style="width: 100%; height: 600px" frameborder="0" srcdoc='{html_response}'></iframe>"""


def check_special_tokens_order(pre_output):
    # The sections must appear in order: inputs, ingredients, instructions and, last, the title.
    return (pre_output.find('<INPUT_START>')
            < pre_output.find('<NEXT_INPUT>') <= pre_output.rfind('<NEXT_INPUT>')
            < pre_output.find('<INPUT_END>')
            < pre_output.find('<INGR_START>')
            < pre_output.find('<NEXT_INGR>') <= pre_output.rfind('<NEXT_INGR>')
            < pre_output.find('<INGR_END>')
            < pre_output.find('<INSTR_START>')
            < pre_output.find('<NEXT_INSTR>') <= pre_output.rfind('<NEXT_INSTR>')
            < pre_output.find('<INSTR_END>')
            < pre_output.find('<TITLE_START>')
            < pre_output.find('<TITLE_END>'))


def make_html_response(title, ingredients, instructions):
    # Render the recipe as simple HTML: a bulleted ingredient list and numbered instructions.
    ingredients_html_list = '<ul><li>' + '</li><li>'.join(ingredients) + '</li></ul>'
    instructions_html_list = '<ol><li>' + '</li><li>'.join(instructions) + '</li></ol>'
    html_response = f'''
    <body>
    <h1>{title}</h1>
    <h2>Ingredientes</h2>
    {ingredients_html_list}
    <h2>Instrucciones</h2>
    {instructions_html_list}
    </body>'''
    return html_response


def rerun_model_output(pre_output):
    # Return True when generation should be retried because the output is not a well-formed recipe.
    if pre_output is None:
        return True
    elif '<RECIPE_END>' not in pre_output:
        logging.info('<RECIPE_END> not in pre_output')
        return True
    pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
    if not all(special_token in pre_output_trimmed for special_token in special_tokens):
        logging.info('Not all special tokens are in pre_output')
        return True
    elif not check_special_tokens_order(pre_output_trimmed):
        logging.info('Special tokens are unordered in pre_output')
        return True
    elif len(pre_output_trimmed.split()) < 75:
        logging.info('Length of the recipe is <75 words')
        return True
    else:
        return False


def check_wrong_ingredients(ingredients):
    # The model sometimes emits ingredients such as 'De sal'; drop the leading 'De ' in that case.
    new_ingredients = []
    for ingredient in ingredients:
        if ingredient.startswith('De '):
            # Slice off the prefix (str.strip would remove characters, not the prefix).
            new_ingredients.append(ingredient[len('De '):].capitalize())
        else:
            new_ingredients.append(ingredient)
    return new_ingredients


def make_recipe(input_ingredients):
    logging.info(f'Received inputs: {input_ingredients}')
    input_ingredients = re.sub(' y ', ', ', input_ingredients)
    # Build the prompt with the structure tokens the model was trained on.
    input = '<RECIPE_START> '
    input += '<INPUT_START> ' + ' <NEXT_INPUT> '.join(input_ingredients.split(', ')) + ' <INPUT_END>'
    input += ' <INGR_START>'
    tokenized_input = tokenizer(input, return_tensors='pt')
    pre_output = None
    i = 0
    # Sample until the output passes the sanity checks, giving up after three attempts.
    while rerun_model_output(pre_output):
        if i == 3:
            return frame_html_response(error_html_response)
        output = model.generate(**tokenized_input,
                                max_length=600,
                                do_sample=True,
                                top_p=0.92,
                                top_k=50,
                                # no_repeat_ngram_size=3,
                                num_return_sequences=3)
        pre_output = tokenizer.decode(output[0], skip_special_tokens=False)
        i += 1
    pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
    # Extract, deduplicate and tidy up the generated ingredients.
    output_ingredients = re.search('<INGR_START> (.*) <INGR_END>', pre_output_trimmed).group(1)
    output_ingredients = output_ingredients.split(' <NEXT_INGR> ')
    output_ingredients = list(set([output_ingredient.strip() for output_ingredient in output_ingredients]))
    output_ingredients = [output_ing.capitalize() for output_ing in output_ingredients]
    output_ingredients = check_wrong_ingredients(output_ingredients)
    # Extract the title and the step-by-step instructions.
    output_title = re.search('<TITLE_START> (.*) <TITLE_END>', pre_output_trimmed).group(1).strip().capitalize()
    output_instructions = re.search('<INSTR_START> (.*) <INSTR_END>', pre_output_trimmed).group(1)
    output_instructions = output_instructions.split(' <NEXT_INSTR> ')
    html_response = make_html_response(output_title, output_ingredients, output_instructions)
    return frame_html_response(html_response)


# Gradio interface (written against the legacy gr.inputs / gr.outputs API).
iface = gr.Interface(
    fn=make_recipe,
    inputs=[
        gr.inputs.Textbox(
            lines=1,
            placeholder='ingrediente_1, ingrediente_2, ..., ingrediente_n',
            label='Dime con qué ingredientes quieres que cocinemos hoy y te sugeriremos una receta tan pronto como nuestros fogones estén libres'),
    ],
    outputs=[
        gr.outputs.HTML(label="¡Esta es mi propuesta para ti! ¡Buen provecho!")
    ],
    examples=[
        ['salmón, zumo de naranja, aceite de oliva, sal, pimienta'],
        ['harina, azúcar, huevos, chocolate, levadura Royal']
    ],
    description=banner)

iface.launch(enable_queue=True)
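
# NOTE: `resources` is a separate module of the Space and is not shown in this file.
# A minimal, hypothetical stub with these two names (the real banner and error page differ)
# is enough to run the app stand-alone, for example:
#
#     # resources.py (hypothetical stub)
#     banner = '<h2>Gastronomía para to2</h2>'
#     error_html_response = '<p>Lo sentimos, no hemos podido generar una receta. Inténtalo de nuevo.</p>'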