hugging2021 committed (verified)
Commit: b75af72 · 1 Parent(s): 547ac8f

Update app.py

Files changed (1):
  1. app.py +54 -0
app.py CHANGED
@@ -1,3 +1,57 @@
+import gradio as gr
+import torch
+from transformers import BertForMaskedLM, BertTokenizer
+import asyncio
+
+# Load the model and tokenizer with force_download=True
+model_name = "bert-base-uncased"
+model = BertForMaskedLM.from_pretrained(model_name, force_download=True)
+tokenizer = BertTokenizer.from_pretrained(model_name, force_download=True)
+
+# Define the inference function
+def inference(input_text):
+    if "[MASK]" not in input_text:
+        return "Error: The input text must contain the [MASK] token."
+
+    # Tokenization
+    inputs = tokenizer(input_text, return_tensors="pt")
+    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
+
+    # Prediction
+    with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits
+
+    # Find the most likely token for [MASK]
+    mask_token_logits = logits[0, mask_token_index, :]
+    top_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()
+
+    # Insert the prediction into the text
+    predicted_token = tokenizer.decode(top_token)
+    result_text = input_text.replace("[MASK]", predicted_token, 1)
+
+    return result_text
+
+# Define the Gradio interface
+iface = gr.Interface(
+    fn=inference,
+    inputs="text",
+    outputs="text",
+    examples=[
+        ["The capital of France is [MASK]."],
+        ["The quick brown fox jumps over the [MASK] dog."]
+    ]
+)
+
+# Launch the interface
+if __name__ == "__main__":
+    # Manually create and set an asyncio event loop
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+
+    iface.launch(server_port=7862)
+
+
 from transformers import pipeline
 unmasker = pipeline('fill-mask', model='bert-base-uncased')
 unmasker("Hello I'm a [MASK] model.")