File size: 1,910 Bytes
5b371ce
 
 
 
 
 
 
 
 
 
8f3f1c2
d37a37c
 
 
 
3a8735f
 
d37a37c
 
 
 
 
 
 
3a8735f
c8e91cc
 
3a8735f
d37a37c
 
 
3a8735f
8d5bf2a
2fdc797
c57666d
d37a37c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import re
import gradio as gr
from transformers import (
    AutoModelForSequenceClassification, 
    AutoTokenizer, 
    pipeline
)
from transformers_interpret import SequenceClassificationExplainer
from hebrewtools.functions import sbl_normalization

model_name = 'martijn75/COHeN_2.0_10_epochs'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

pipe = pipeline("text-classification", model=model_name)

pattern = re.compile("[^\s\u05d0-\u05ea\u05b0-\u05bc\u05be\u05c1\u05c2\u05c7]")

def predict(text):
  text = " ".join([word for word in text.split() if word not in ['\u05e1', '\u05e4', '']])
  text = re.sub(pattern, "", text)
  text = sbl_normalization(text)
  word_attributions = cls_explainer(text)
  results = pipe(text)[0]
  label_keys = {'LABEL_0' : 'ABH', 'LABEL_1' : 'CBH', 'LABEL_2' : 'TBH', 'LABEL_3' : 'LBH'}
  label = f"{label_keys[results['label']]} ({results['score']:.2})"
  return label, word_attributions[1:-1]

iface = gr.Interface(
    fn=predict,
    inputs=gr.Text(label="Input Text"),
    outputs=[gr.Text(label="Label"), gr.HighlightedText(label="Word Importance", show_legend=True, color_map={"-": "red", "+": "blue"})],
    theme=gr.themes.Base(),
    examples=[['וְסָפְדָה הָאָרֶץ מִשְׁפָּחוֺת מִשְׁפָּחוֺת לְבָד מִשְׁפַּחַת בֵּית־דָּוִיד לְבָד וּנְשֵׁיהֶם לְבָד מִשְׁפַּחַת בֵּית־נָתָן לְבָד וּנְשֵׁיהֶם לְבָד'], ['וַיֹּאמֶר דָּוִד אֶל־אוּרִיָּה שֵׁב בָּזֶה גַּם־הַיּוֺם וּמָחָר אֲשַׁלְּחֶךָּ וַיֵּשֶׁב אוּרִיָּה בִירוּשָׁלִַם בַּיּוֺם הַהוּא וּמִמָּחֳרָת']]
)

iface.launch()