hugging2021 committed
Commit 29879ee · verified · 1 Parent(s): 3db6291

Update app.py

Files changed (1)
  1. app.py +35 -30
app.py CHANGED
@@ -1,42 +1,47 @@
  import gradio as gr
- from transformers import pipeline
-
- pipe = pipeline("fill-mask", model="google-bert/bert-base-uncased")
-
- title = "BERT"
- description = "Gradio Demo for BERT. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1810.04805' target='_blank'>BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a></p>"
-
- examples = [
-     ['Paris is the [MASK] of France.', 'bert-base-cased']
- ]
-
- # Load the interfaces for the models
- io1 = gr.Interface.load("huggingface/bert-base-cased")
- io2 = gr.Interface.load("huggingface/bert-base-uncased")
-
- def inference(inputtext, model):
-     if "[MASK]" not in inputtext:
-         return {"error": "The input text must contain the [MASK] token."}
-
-     if model == "bert-base-cased":
-         return io1(inputtext)
-     elif model == "bert-base-uncased":
-         return io2(inputtext)
-     else:
-         return {"error": "Invalid model selected"}
+ import torch
+ from transformers import BertForMaskedLM, BertTokenizer
+
+ # Load the model and tokenizer
+ model_name = "bert-base-uncased"
+ model = BertForMaskedLM.from_pretrained(model_name)
+ tokenizer = BertTokenizer.from_pretrained(model_name)
+
+ # Define the inference function
+ def inference(input_text):
+     if "[MASK]" not in input_text:
+         return "Error: The input text must contain the [MASK] token."
+
+     # Tokenize the input
+     inputs = tokenizer(input_text, return_tensors="pt")
+     mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
+
+     # Run the prediction
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits = outputs.logits
+
+     # Find the most likely token for [MASK]
+     mask_token_logits = logits[0, mask_token_index, :]
+     top_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()
+
+     # Insert the prediction into the text
+     predicted_token = tokenizer.decode(top_token)
+     result_text = input_text.replace("[MASK]", predicted_token, 1)
+
+     return result_text
+
+ # Define the Gradio interface
  iface = gr.Interface(
      fn=inference,
-     inputs=[
-         gr.Textbox(label="Context", lines=10, placeholder="Enter text with [MASK] token"),
-         gr.Dropdown(choices=["bert-base-cased", "bert-base-uncased"], value="bert-base-cased", label="model")
-     ],
-     outputs=gr.JSON(label="Output"),  # We use JSON to display errors or outputs
-     examples=examples,
-     article=article,
-     title=title,
-     description=description
+     inputs="text",
+     outputs="text",
+     examples=[
+         ["The capital of France is [MASK]."],
+         ["The quick brown fox jumps over the [MASK] dog."]
+     ]
  )

- iface.launch(share=True)
+ # Launch the interface
+ if __name__ == "__main__":
+     iface.launch()
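
For a quick sanity check outside the Space, the updated inference() can be exercised without the Gradio UI. A minimal sketch, not part of this commit, assuming app.py is importable from the working directory and that torch and transformers are installed; the printed tokens depend on the bert-base-uncased weights, so the expected outputs are indicative rather than guaranteed:

# Hypothetical usage sketch: call the new inference() directly.
# Importing app runs the module-level model loading but does not call
# launch(), which is guarded by the __main__ check above.
from app import inference

# The single [MASK] is replaced by the model's top-1 token (likely "paris",
# since bert-base-uncased is a lowercase model).
print(inference("The capital of France is [MASK]."))

# Input without a [MASK] token takes the new error path.
print(inference("No mask token here."))

Note that only the first [MASK] is filled: mask_token_logits covers every mask position, but topk(...).indices[0] keeps the first one, and str.replace(..., 1) substitutes a single occurrence.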