Commit
·
90ed2ce
1
Parent(s):
c21078f
fix: Type hints
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
|
|
7 |
from transformers import pipeline
|
8 |
from shap import Explainer
|
9 |
import numpy as np
|
|
|
10 |
|
11 |
|
12 |
def main():
|
@@ -27,8 +28,8 @@ def main():
|
|
27 |
"😊",
|
28 |
]
|
29 |
|
30 |
-
def classification(text) ->
|
31 |
-
output:
|
32 |
print(output)
|
33 |
|
34 |
explainer = Explainer(pipe)
|
@@ -40,7 +41,7 @@ def main():
|
|
40 |
if np.abs(shap_values).max() <= boundary:
|
41 |
boundary = np.abs(shap_values).max() - 1e-6
|
42 |
|
43 |
-
words:
|
44 |
records = list()
|
45 |
char_idx = 0
|
46 |
for word, shap_value in zip(words, shap_values):
|
|
|
7 |
from transformers import pipeline
|
8 |
from shap import Explainer
|
9 |
import numpy as np
|
10 |
+
from typing import Tuple, Dict, List
|
11 |
|
12 |
|
13 |
def main():
|
|
|
28 |
"😊",
|
29 |
]
|
30 |
|
31 |
+
def classification(text) -> Tuple[Dict[str, float], dict]:
|
32 |
+
output: List[dict] = pipe(text)[0]
|
33 |
print(output)
|
34 |
|
35 |
explainer = Explainer(pipe)
|
|
|
41 |
if np.abs(shap_values).max() <= boundary:
|
42 |
boundary = np.abs(shap_values).max() - 1e-6
|
43 |
|
44 |
+
words: List[str] = explanation.data[0]
|
45 |
records = list()
|
46 |
char_idx = 0
|
47 |
for word, shap_value in zip(words, shap_values):
|