Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
from datasets import load_dataset | |
from transformers import pipeline | |
from textwrap import dedent | |
from email import message_from_file | |
from email.header import decode_header | |
# Run on the first CUDA GPU when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Build the classification pipeline from the fine-tuned multilingual
# DistilBERT email-spam checkpoint and place it on the device chosen above.
pipe = pipeline(
    model="1aurent/distilbert-base-multilingual-cased-finetuned-email-spam",
    device=device,
)
# Gradio handler for the "Raw Text" tab.
def classify_raw(text):
    """Classify a pre-formatted block of email headers given as free text.

    Args:
        text: Header lines such as ``From: ...`` / ``Subject: ...``, as typed
            or pasted by the user.

    Returns:
        The pipeline output for the two highest-scoring labels (``top_k=2``).
    """
    predictions = pipe(text, top_k=2)
    return predictions
# Gradio handler for the "Form" tab.
def classify_form(mailfrom, x_mailfrom, to, reply_to, subject):
    """Classify an email from its individual header fields.

    The fields are assembled into the same ``Name: value`` layout, one
    header per line, that the "Raw Text" tab expects, then passed to the
    model.

    Args:
        mailfrom: Value of the ``From`` header.
        x_mailfrom: Value of the ``X-MailFrom`` header.
        to: Value of the ``To`` header.
        reply_to: Value of the ``Reply-To`` header.
        subject: Value of the ``Subject`` header.

    Returns:
        The pipeline output for the two highest-scoring labels (``top_k=2``).
    """
    # Join the lines explicitly rather than dedenting an indented f-string:
    # dedent() finds no common indent as soon as an interpolated value
    # contains a newline, which leaked the template's indentation into the
    # text sent to the model.
    lines = (
        f"From: {mailfrom}",
        f"X-MailFrom: {x_mailfrom}",
        f"To: {to}",
        f"Reply-To: {reply_to}",
        f"Subject: {subject}",
    )
    text = "\n".join(lines).strip()
    return pipe(text, top_k=2)
# helper to extract a decoded, single-line header value from a parsed email
def get_header(message, header_name: str) -> str:
    """Return the decoded value of one header as a single line of text.

    Every chunk returned by ``decode_header`` is kept and decoded with the
    charset it declares. An RFC 2047 encoded display name is therefore not
    dropped in favour of the plain ``<address>`` part that follows it.
    ``make_header`` re-joins the chunks, which keeps the space between a
    plain word and an encoded word even across a folded line break.

    Args:
        message: A parsed ``email.message.Message``.
        header_name: Header to read, e.g. ``"From"``.

    Returns:
        The decoded text, with line breaks replaced by spaces and surrounding
        whitespace stripped. Returns ``""`` when the header is absent. A
        header containing a malformed encoded word is returned undecoded
        rather than dropped.
    """
    from email.errors import HeaderParseError
    from email.header import make_header

    raw = message[header_name]
    if raw is None:
        return ""
    try:
        parts = decode_header(raw)
    except HeaderParseError:
        # e.g. invalid base64 inside =?...?b?...?=: keep the raw text.
        return " ".join(str(raw).splitlines()).strip()

    chunks = []
    for payload, charset in parts:
        if isinstance(payload, bytes):
            payload = _decode_header_bytes(payload, charset)
        # The text is already decoded. The charset passed on only tells
        # make_header which chunks were encoded words, so it can place the
        # spaces between encoded and plain words correctly.
        is_encoded_word = charset not in (None, "us-ascii", "unknown-8bit")
        chunks.append((payload, "utf-8" if is_encoded_word else None))
    header = str(make_header(chunks))
    # Folded headers keep their line breaks in the parsed value; unfold them.
    return " ".join(header.splitlines()).strip()


def _decode_header_bytes(payload, charset) -> str:
    """Decode one header chunk with its charset, falling back to UTF-8.

    Undecodable bytes are dropped (``errors="ignore"``).
    """
    try:
        return payload.decode(charset or "utf-8", errors="ignore")
    except LookupError:
        # Unknown codec name, or the "unknown-8bit" pseudo-charset the email
        # package uses for raw 8-bit header bytes.
        return payload.decode("utf-8", errors="ignore")
# Gradio handler for the "File" tab.
def classify_file(file):
    """Classify an uploaded ``.eml`` file from its main headers.

    Args:
        file: Value produced by the ``gr.File`` input. This is either a path
            string or a file-like object exposing the path as ``.name``;
            Gradio has returned both forms across versions.

    Returns:
        The prediction of ``classify_form`` for the message's ``From``,
        ``X-MailFrom``, ``To``, ``Reply-To`` and ``Subject`` headers.

    Raises:
        gr.Error: If no file was uploaded.
    """
    from email import message_from_binary_file

    if file is None:
        raise gr.Error("Please upload an .eml file.")
    path = getattr(file, "name", file)
    # Parse the raw bytes: text mode would decode with the platform's default
    # encoding and fail on 8-bit / non-UTF-8 messages. The `with` statement
    # also closes the file, which the previous bare open() never did.
    with open(path, "rb") as fp:
        message = message_from_binary_file(fp)
    return classify_form(
        mailfrom=get_header(message, "From"),
        x_mailfrom=get_header(message, "X-MailFrom"),
        to=get_header(message, "To"),
        reply_to=get_header(message, "Reply-To"),
        subject=get_header(message, "Subject"),
    )
# Heading and one-line tagline for the demo page.
title = "Email Spam Classifier"
description = "\n" "Spam or ham ?" "\n"
# Top-level app container. The browser tab shows the app title instead of
# Gradio's default page title.
demo = gr.Blocks(title=title)
# "Raw Text" tab: the user pastes an already formatted header block.
raw_interface = gr.Interface(
    fn=classify_raw,
    inputs=gr.Textbox(
        label="Formatted Email Header",
        lines=5,
        placeholder="\n".join(
            [
                "From: Laurent Fainsin <[email protected]>",
                "X-MailFrom: Laurent Fainsin <[email protected]>",
                "To: net7 <[email protected]>",
                "Reply-To: Laurent Fainsin <[email protected]>",
                "Subject: Re: Demande d'un H24 net7",
            ]
        ),
    ),
    outputs="json",
    api_name="predict_raw_text",
)
# "Form" tab: one textbox per header, in the positional order that
# classify_form expects (From, X-MailFrom, To, Reply-To, Subject).
form_interface = gr.Interface(
    fn=classify_form,
    inputs=[
        gr.Textbox(label=field_label, placeholder=example)
        for field_label, example in (
            ("From", "Laurent Fainsin <[email protected]>"),
            ("X-MailFrom", "Laurent Fainsin <[email protected]>"),
            ("To", "net7 <[email protected]>"),
            ("Reply-To", "Laurent Fainsin <[email protected]>"),
            ("Subject", "Re: Demande d'un H24 net7"),
        )
    ],
    outputs="json",
    api_name="predict_form",
)
# "File" tab: upload a raw .eml message; classify_file extracts its headers.
file_interface = gr.Interface(
    fn=classify_file,
    inputs=gr.File(file_types=[".eml"], label="Email File"),
    outputs="json",
    api_name="predict_file",
)
# Assemble the UI: a heading, then the three input modes as tabs.
with demo:
    # Display the title and description constants defined above.
    gr.Markdown(f"# {title}\n{description}")
    gr.TabbedInterface(
        interface_list=[raw_interface, form_interface, file_interface],
        tab_names=["Raw Text", "Form", "File"],
    )

# Start the server only when executed as a script, so importing this module
# (e.g. for tests or Gradio's reload mode) does not block on launch().
if __name__ == "__main__":
    demo.launch()