Spaces:
Runtime error
Runtime error
import logging | |
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
import torch | |
import spaces | |
########################## | |
# CONFIGURATION | |
########################## | |
# Root logger setup: emit INFO and above, one record per line formatted as
# "<time> - <logger name> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Demo examples shown in the UI: each row is [image path, raw OCR text].
# The texts intentionally keep their OCR mistakes; they are the spellcheck inputs.
EXAMPLES = [
    [
        "images/ingredients_1.jpg",
        "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron.",
    ],
    [
        "images/ingredients_2.jpg",
        "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35.",
    ],
    # Disabled example:
    # ["images/ingredients_3.jpg", "tural basmati rice - cooked (98%), rice bran oil, salt"],
    [
        "images/ingredients_4.jpg",
        "Eau de noix de coco 93.9%, Arôme natutel de fruit",
    ],
    [
        "images/ingredients_5.jpg",
        "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France.",
    ],
]

# Hugging Face Hub identifier used to load both the tokenizer and the model.
MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"
# Markdown shown at the top of the demo page (rendered through `gr.Markdown`).
PRESENTATION = """# 🍊 Ingredients Spellcheck - Open Food Facts
Open Food Facts is a non-profit organization building the largest open food database in the world. 🌎
When a product is added to the database, all its details, such as allergens, additives, or nutritional values, are either written down by the contributor,
or automatically extracted from the product pictures using OCR.
However, it often happens that the information extracted by OCR contains typos and errors due to bad quality pictures: low-definition, curved product, light reflection, etc.
To solve this problem, we developed an 🍊 **Ingredient Spellcheck** 🍊, a model capable of correcting typos in a list of ingredients following a defined guideline.
The model, based on Mistral-7B-v0.3, was fine-tuned on thousands of corrected lists of ingredients extracted from the database. More information in the model card.
### *Project in progress* 🏗️
## 👇 Links
* Open Food Facts website: https://world.openfoodfacts.org/discover
* Open Food Facts Github: https://github.com/openfoodfacts
* Spellcheck project: https://github.com/openfoodfacts/openfoodfacts-ai/tree/develop/spellcheck
* Model card: https://huggingface.co/openfoodfacts/spellcheck-mistral-7b
"""
# Dummy tensor moved to CUDA at import time; `process` reads `zero.device` to
# decide where to send the tokenized prompt.
# NOTE(review): `.cuda()` is unconditional, so this presumably relies on the
# `spaces` (ZeroGPU) runtime or a CUDA-capable host — confirm the behaviour on
# CPU-only machines.
zero = torch.Tensor([0]).cuda()
# Fixed seed so generation is as reproducible as possible (this does not
# guarantee 100% reproducibility).
set_seed(42)
########################## | |
# LOADING | |
########################## | |
# Tokenizer: loaded from the same repository as the model.
logging.info(f"Load tokenizer from {MODEL_ID}.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse the end-of-sequence token (and its id) as the padding token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
# Model: causal LM loaded once at import time and shared by every request.
logging.info(f"Load model from {MODEL_ID}.")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # Device placement is delegated to the library.
    device_map="auto",
    # Left disabled by the original author: flash attention is noted as not
    # supported by ZeroGPU, and no explicit dtype is requested.
    # NOTE(review): without `torch_dtype` the library default precision is
    # used — confirm the memory footprint is acceptable for a 7B model.
    # attn_implementation="flash_attention_2", # Not supported by ZERO GPU
    # torch_dtype=torch.bfloat16,
)
########################## | |
# FUNCTIONS | |
########################## | |
@spaces.GPU  # ZeroGPU only attaches a GPU while a decorated function runs.
def process(text: str) -> str:
    """Generate the corrected list of ingredients for ``text``.

    Uses the module-level ``tokenizer`` and ``model``: the text is wrapped in
    the instruction prompt, the model greedily generates a continuation, and
    only that continuation is returned.

    Args:
        text (str): Raw list of ingredients (typically OCR output).

    Returns:
        str: The generated correction with surrounding whitespace removed, or
            an empty string when ``text`` is empty or only whitespace.
    """
    # Nothing to correct: skip the (expensive) generation entirely.
    if not text or not text.strip():
        return ""

    prompt = prepare_instruction(text)
    # Keep the whole encoding (input ids AND attention mask) and move it to
    # the same device the original code targeted.
    encoded = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    ).to(zero.device)
    output = model.generate(
        **encoded,
        do_sample=False,
        max_new_tokens=512,
        # Padding reuses EOS in this app; be explicit so generate() does not
        # have to guess.
        pad_token_id=tokenizer.pad_token_id,
    )
    # generate() returns prompt tokens followed by the new tokens. Drop the
    # prompt by token count: slicing the decoded text by len(prompt) is only
    # correct when decoding reproduces the prompt character for character.
    prompt_length = encoded["input_ids"].shape[-1]
    generated_tokens = output[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
def prepare_instruction(text: str) -> str:
    """Build the instruction prompt fed to the spellcheck model.

    The list of ingredients is placed between a fixed header and a fixed
    "correction" marker. Per the original author's note, this layout is the
    same one used during training, so it must not be altered.

    Args:
        text (str): List of ingredients

    Returns:
        str: Instruction.
    """
    header = "###Correct the list of ingredients:\n"
    answer_marker = "\n\n###Correction:\n"
    # str.join (like `+`) rejects non-string input with a TypeError.
    return "".join([header, text, answer_marker])
########################## | |
# GRADIO SETUP | |
########################## | |
with gr.Blocks() as demo:
    # Introductory markdown at the top of the page.
    gr.Markdown(PRESENTATION)

    with gr.Row():
        # Read-only picture of the ingredient list (filled in by the examples).
        with gr.Column():
            image = gr.Image(type="pil", label="image_input", interactive=False)
        # Editable input text, trigger button and read-only result.
        with gr.Column():
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value="Run spellcheck")
            correction = gr.Textbox(label="Correction", interactive=False)

    with gr.Row():
        # Clicking an example only pre-fills the image and the text box.
        # No `fn` is attached here: `process` accepts a single text argument,
        # while the examples carry two inputs (image, text) and no outputs, so
        # it could never be called successfully from this component.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
        )

    # Run the spellcheck on the current text and show the result.
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction],
    )

if __name__ == "__main__":
    # Launch the demo
    demo.launch()