update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
3 |
-
import csv
|
4 |
|
5 |
MODEL_URL = "https://huggingface.co/dsfsi/PuoBERTa-News"
|
6 |
WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"
|
@@ -22,18 +22,20 @@ categories = {
|
|
22 |
}
|
23 |
|
24 |
def prediction(news):
|
25 |
-
|
26 |
-
preds =
|
27 |
preds_dict = {categories.get(pred['label'], pred['label']): pred['score'] for pred in preds[0]}
|
28 |
return preds_dict
|
29 |
|
30 |
def file_prediction(file):
|
|
|
|
|
31 |
if file.name.endswith('.csv'):
|
32 |
-
file.seek(0)
|
33 |
reader = csv.reader(file.read().decode('utf-8').splitlines())
|
34 |
-
news_list = [row[0] for row in reader if row]
|
35 |
else:
|
36 |
-
file.seek(0)
|
37 |
file_content = file.read().decode('utf-8')
|
38 |
news_list = file_content.splitlines()
|
39 |
|
@@ -41,9 +43,9 @@ def file_prediction(file):
|
|
41 |
for news in news_list:
|
42 |
if news.strip():
|
43 |
pred = prediction(news)
|
44 |
-
results.append([news, pred])
|
45 |
|
46 |
-
return results
|
47 |
|
48 |
gradio_ui = gr.Interface(
|
49 |
fn=prediction,
|
@@ -67,4 +69,3 @@ gradio_file_ui = gr.Interface(
|
|
67 |
gradio_combined_ui = gr.TabbedInterface([gradio_ui, gradio_file_ui], ["Text Input", "File Upload"])
|
68 |
|
69 |
gradio_combined_ui.launch()
|
70 |
-
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
3 |
+
import csv
|
4 |
|
5 |
MODEL_URL = "https://huggingface.co/dsfsi/PuoBERTa-News"
|
6 |
WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"
|
|
|
22 |
}
|
23 |
|
24 |
def prediction(news):
|
25 |
+
classifier = pipeline("text-classification", tokenizer=tokenizer, model=model, return_all_scores=True)
|
26 |
+
preds = classifier(news)
|
27 |
preds_dict = {categories.get(pred['label'], pred['label']): pred['score'] for pred in preds[0]}
|
28 |
return preds_dict
|
29 |
|
30 |
def file_prediction(file):
|
31 |
+
news_list = []
|
32 |
+
|
33 |
if file.name.endswith('.csv'):
|
34 |
+
file.seek(0)
|
35 |
reader = csv.reader(file.read().decode('utf-8').splitlines())
|
36 |
+
news_list = [row[0] for row in reader if row]
|
37 |
else:
|
38 |
+
file.seek(0)
|
39 |
file_content = file.read().decode('utf-8')
|
40 |
news_list = file_content.splitlines()
|
41 |
|
|
|
43 |
for news in news_list:
|
44 |
if news.strip():
|
45 |
pred = prediction(news)
|
46 |
+
results.append([news, pred]) # Return each news and its prediction as a row
|
47 |
|
48 |
+
return results # Gradio expects a list of lists or dicts for DataFrame
|
49 |
|
50 |
gradio_ui = gr.Interface(
|
51 |
fn=prediction,
|
|
|
69 |
gradio_combined_ui = gr.TabbedInterface([gradio_ui, gradio_file_ui], ["Text Input", "File Upload"])
|
70 |
|
71 |
gradio_combined_ui.launch()
|
|