|
import hmac
import os
import re
import subprocess
import textwrap

import requests
from flask import Flask, request
from langdetect import detect
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
|
|
|
# English pipeline: the stock BART tokenizer plus a fine-tuned checkpoint
# pinned to the "worked" revision.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "GuysTrans/bart-base-finetuned-xsum", revision="worked"
)

# Vietnamese pipeline: tokenizer and model are loaded from the same repo, but
# only the model is pinned to the "worked" revision.
# NOTE(review): no revision is passed for the tokenizer, so it tracks the
# repo's default revision and could drift from the pinned model — consider
# pinning both to the same revision.
vn_tokenizer = AutoTokenizer.from_pretrained(
    "GuysTrans/bart-base-vn-ehealth-vn-tokenizer"
)
vn_model = AutoModelForSeq2SeqLM.from_pretrained(
    "GuysTrans/bart-base-vn-ehealth-vn-tokenizer", revision="worked"
)
|
|
|
# Phrase -> replacement applied to model output by post_process (an empty
# string deletes the phrase). Entries are applied in insertion order, so the
# long "Hello and Welcome ..." greeting must come before the shorter
# "Hello"/"Hi" entries or it would never match as a whole.
map_words = {
    "Hello and Welcome to 'Ask A Doctor' service": "",
    "Hello,": "",
    "Hi,": "",
    "Hello": "",
    "Hi": "",
    "Ask A Doctor": "MedForum",
    "H C M": "Med Forum",
}

# post_process drops every "."-separated sentence that contains any of these
# substrings (compared case-insensitively).
word_remove_sentence = ["Welcome to"]
|
|
|
|
|
def generate_summary(question, model, tokenizer):
    """Run *question* through a seq2seq *model* and decode the generated text.

    The input is truncated to 512 tokens and padded up to exactly 512.
    Generation uses sampling (``do_sample=True``) combined with 4 beams, so
    repeated calls on the same input can return different text.

    Args:
        question: The user's text.
        model: A transformers seq2seq model exposing ``device`` and ``generate``.
        tokenizer: The tokenizer matching *model*.

    Returns:
        A ``(outputs, output_str)`` tuple: the raw token ids returned by
        ``model.generate`` and the list of strings decoded from them with
        special tokens stripped.
    """
    encoded = tokenizer(
        question,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    device = model.device
    input_ids = encoded.input_ids.to(device)

    generation_options = dict(
        attention_mask=encoded.attention_mask.to(device),
        # NOTE(review): 4096 new tokens may exceed the decoder's position
        # limit for BART-sized models — confirm against the model config.
        max_new_tokens=4096,
        do_sample=True,
        num_beams=4,
        top_k=50,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )
    generated = model.generate(input_ids, **generation_options)

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return generated, decoded
|
|
|
|
|
app = Flask(__name__)

# Messenger Send API endpoint used by send_message.
# NOTE(review): v2.6 is a very old Graph API version — confirm it is still
# accepted by Facebook before relying on it.
FB_API_URL = 'https://graph.facebook.com/v2.6/me/messages'

# Token that must match "hub.verify_token" during the webhook handshake
# (see verify_webhook). It can now be supplied via the VERIFY_TOKEN environment
# variable so it can be rotated without a code change. The hard-coded fallback
# keeps existing deployments working, but it is a secret committed to source
# and should be removed once the env var is set everywhere.
VERIFY_TOKEN = os.environ.get(
    'VERIFY_TOKEN', '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw='
)

# Page access token passed to the Send API. Required: importing this module
# raises KeyError if the environment variable is not set.
PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN']
|
|
|
|
|
def get_bot_response(message):
    """Build the chatbot's reply text for *message*.

    Messages that langdetect classifies as Vietnamese ("vi") are answered with
    the Vietnamese model, tokenizer and template. Everything else uses the
    English ones, including text langdetect cannot classify (empty, digits,
    emoji), which previously raised and failed the request. The first decoded
    model output is cleaned by ``post_process`` and inserted into the template.

    Args:
        message: The user's text.

    Returns:
        The formatted reply string.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        lang = detect(message)
    except LangDetectException:
        # detect() raises when the text has no usable features; fall back to
        # the English pipeline instead of crashing the request.
        lang = None

    if lang == "vi":
        model_use, tokenizer_use = vn_model, vn_tokenizer
        template = "Chào mừng bạn đến với dịch vụ MedForum chatbot. %s. Cảm ơn bạn đã sử dụng MedForum."
    else:
        model_use, tokenizer_use = model, tokenizer
        # Brand name fixed from the "MedForRum" typo so it matches "MedForum".
        template = "Welcome to MedForum chatbot service. %s. Thanks for asking on MedForum."

    # generate_summary returns (token_ids, decoded_strings); use the first string.
    summary = generate_summary(message, model_use, tokenizer_use)[1][0]
    return template % post_process(summary)
|
|
|
|
|
def verify_webhook(req):
    """Answer the webhook verification handshake.

    Args:
        req: The incoming Flask request; its query args are read.

    Returns:
        The "hub.challenge" query value when "hub.verify_token" matches
        ``VERIFY_TOKEN``, otherwise the string "incorrect".
    """
    supplied = req.args.get("hub.verify_token")
    # Constant-time comparison so the secret cannot be recovered through
    # response-timing differences. compare_digest needs operands of the same
    # type (and ASCII-only str), so compare UTF-8 bytes. A missing token is
    # simply a mismatch.
    if supplied is not None and hmac.compare_digest(
        supplied.encode("utf-8"), VERIFY_TOKEN.encode("utf-8")
    ):
        return req.args.get("hub.challenge")
    return "incorrect"
|
|
|
|
|
def respond(sender, message):
    """Generate a reply to *message*, send it to *sender*, and return it.

    Args:
        sender: Recipient id passed through to ``send_message``.
        message: The user's text, passed to ``get_bot_response``.

    Returns:
        The reply text that was sent.
    """
    reply = get_bot_response(message)
    send_message(sender, reply)
    return reply
|
|
|
|
|
def is_user_message(message):
    """Tell whether a webhook messaging event is a user-sent text message.

    An event qualifies when it has a non-empty "message" object with non-empty
    "text" and is not flagged "is_echo". The result is truthy/falsy rather than
    strictly bool: the first missing or empty piece is returned as-is (None,
    {} or ""), otherwise the negated "is_echo" flag.
    """
    body = message.get('message')
    if not body:
        return body

    text = body.get('text')
    if not text:
        return text

    return not body.get("is_echo")
|
|
|
|
|
@app.route("/webhook", methods=['GET', 'POST'])
def listen():
    """Handle the Messenger webhook at the ``/webhook`` endpoint.

    GET requests are the verification handshake and are delegated to
    ``verify_webhook``. For POST requests, every user text message found in
    every entry of the JSON body is answered via ``respond``, and "ok" is
    returned once all events are processed.
    """
    if request.method == 'GET':
        return verify_webhook(request)

    if request.method == 'POST':
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}

        # A single POST may carry several entries, each with several messaging
        # events. The previous code read only entry[0] and raised (HTTP 500)
        # whenever "entry" or "messaging" was absent or empty.
        for entry in payload.get('entry') or []:
            for event in entry.get('messaging') or []:
                if is_user_message(event):
                    text = event['message']['text']
                    sender_id = event['sender']['id']
                    respond(sender_id, text)

        return "ok"
|
|
|
|
|
def send_message(recipient_id, text, timeout=10):
    """Send *text* to *recipient_id* through the Messenger Send API.

    Args:
        recipient_id: Id of the recipient (the sender id taken from an
            incoming webhook event).
        text: Message body to send.
        timeout: Seconds to wait for Facebook before giving up. The request
            previously had no timeout and could block the worker indefinitely.

    Returns:
        The decoded JSON body of Facebook's response. This can be an error
        payload; the HTTP status is not checked.

    Raises:
        requests.exceptions.Timeout: if Facebook does not respond within
            *timeout* seconds.
    """
    payload = {
        'message': {'text': text},
        'recipient': {'id': recipient_id},
        'notification_type': 'regular',
    }
    auth = {'access_token': PAGE_ACCESS_TOKEN}

    response = requests.post(
        FB_API_URL,
        params=auth,
        json=payload,
        timeout=timeout,
    )
    return response.json()
|
|
|
|
|
@app.route("/webhook/chat", methods=['POST'])
def chat():
    """Answer one chat message posted as JSON.

    Expects a body like ``{"message": "<text>"}`` and returns
    ``{"message": "<bot reply>"}``. A body that is not a JSON object, or that
    lacks a non-empty string "message", now gets HTTP 400 with an error
    message instead of crashing with a 500.
    """
    payload = request.get_json(silent=True)
    message = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(message, str) or not message.strip():
        return {"error": "JSON body must contain a non-empty 'message' string"}, 400

    response = get_bot_response(message)
    return {"message": response}
|
|
|
def post_process(output, remove_markers=None, replacements=None):
    """Clean raw model output before it is shown to the user.

    Steps:
      1. Split on "." and drop every sentence containing any marker from
         *remove_markers* (case-insensitive substring match).
      2. Apply every *replacements* entry in insertion order, case-insensitively
         and without a replacement limit. Keys are matched literally (not as
         regexes). Where a key starts or ends with a word character, it only
         matches at a word boundary, so e.g. "Hi" no longer eats the start of
         "His".
      3. Dedent, strip, and re-wrap the text to 120 columns.

    Args:
        output: Decoded model text.
        remove_markers: Substrings marking a sentence for removal. Defaults to
            the module-level ``word_remove_sentence``.
        replacements: Mapping of phrase -> replacement text. Defaults to the
            module-level ``map_words``.

    Returns:
        The cleaned text, wrapped with ``textwrap.fill(..., width=120)``.
    """
    if remove_markers is None:
        remove_markers = word_remove_sentence
    if replacements is None:
        replacements = map_words

    markers = [marker.lower() for marker in remove_markers]
    # Build a new list: calling list.remove() while iterating the same list
    # skipped the sentence that followed each removed one.
    kept = [
        sentence
        for sentence in output.split(".")
        if not any(marker in sentence.lower() for marker in markers)
    ]
    output = ".".join(kept)

    for phrase, replacement in replacements.items():
        if not phrase:
            continue  # an empty pattern would match between every character
        pattern = re.escape(phrase)
        if re.match(r"\w", phrase[0]):
            pattern = r"\b" + pattern
        if re.match(r"\w", phrase[-1]):
            pattern += r"\b"
        # flags must be passed by keyword: re.sub's 4th positional argument is
        # `count`, so the old `re.sub(item, repl, output, re.I)` stopped after 2
        # replacements and never ignored case. The callable inserts the
        # replacement literally (no backslash/group-reference processing).
        output = re.sub(
            pattern,
            lambda _match, text=replacement: text,
            output,
            flags=re.IGNORECASE,
        )

    return textwrap.fill(textwrap.dedent(output).strip(), width=120)
|
|
|
|
|
|
|
# Start an autossh reverse tunnel to serveo.net that maps remote
# "guysmedchatt:80" to the local port 7860. The process handle is not kept.
# NOTE(review): this runs as a side effect of importing the module (every
# worker/reloader process spawns its own tunnel). StrictHostKeyChecking=no
# skips SSH host-key verification, which is a MITM risk. The key path
# "id_rsa" is resolved relative to the current working directory.
subprocess.Popen(
    args=[
        "autossh",
        "-M", "0",
        "-tt",
        "-o", "StrictHostKeyChecking=no",
        "-i", "id_rsa",
        "-R", "guysmedchatt:80:localhost:7860",
        "serveo.net",
    ]
)
|
|
|
|