import gradio as gr
import torch
from transformers import BertForMaskedLM, BertTokenizer
# Load the model and tokenizer with force_download=True
model_name = "bert-base-uncased"
model = BertForMaskedLM.from_pretrained(model_name, force_download=True)
tokenizer = BertTokenizer.from_pretrained(model_name, force_download=True)
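# Note: force_download=True bypasses the local Hugging Face cache and re-downloads
# the weights on every start; dropping the flag lets the cache be reused.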
# Define the inference function
def inference(input_text):
    if "[MASK]" not in input_text:
        return "Error: The input text must contain the [MASK] token."

    # Tokenize the input
    inputs = tokenizer(input_text, return_tensors="pt")
    # torch.where returns (batch_indices, position_indices); take the positions
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    # Run the forward pass without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Find the most likely token for the first [MASK]
    mask_token_logits = logits[0, mask_token_index, :]
    top_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()

    # Insert the prediction into the text (only the first [MASK] is replaced)
    predicted_token = tokenizer.decode(top_token)
    result_text = input_text.replace("[MASK]", predicted_token, 1)
    return result_text
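# Illustrative behavior (the exact prediction depends on the model):
#   inference("The capital of France is [MASK].")
#   -> "The capital of France is paris."  # bert-base-uncased emits lowercase tokens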
# Define the Gradio interface
iface = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",
    examples=[
        ["The capital of France is [MASK]."],
        ["The quick brown fox jumps over the [MASK] dog."]
    ]
)
# Launch the interface (port 7862 avoids Gradio's default 7860)
if __name__ == "__main__":
    iface.launch(server_port=7862)