import logging
import random
import re
import unicodedata

import arabic_reshaper
import gradio as gr
import torch
from arabert.preprocess import ArabertPreprocessor
from bidi.algorithm import get_display
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline


# Question-generation pair: passed to generate_questions() by generate_question().
tokenizer1 = AutoTokenizer.from_pretrained("Reham721/Subjective_QG")
model1 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/Subjective_QG")

# Distractor-generation pair: used by generate_distractors().
# NOTE(review): the tokenizer is loaded from "google/mt5-base" while the model
# comes from "Reham721/MCQs_QG" -- confirm this pairing is intended.
tokenizer2 = AutoTokenizer.from_pretrained("google/mt5-base")
model2 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/MCQs_QG")

# Preprocessor and extractive QA pipeline used by get_sorted_questions() to
# score the generated questions against the context.
prep = ArabertPreprocessor("aubmindlab/araelectra-base-discriminator")
qa_pipe = pipeline(
    "question-answering",
    model="wissamantoun/araelectra-base-artydiqa",
)


def generate_questions(model, tokenizer, input_sequence):
    """Produce candidate questions for ``input_sequence`` via beam search.

    Args:
        model: Object exposing ``generate(input_ids=..., **options)``.
        tokenizer: Object exposing ``encode`` and ``decode``.
        input_sequence: Prompt text, passed to the tokenizer unchanged.

    Returns:
        A list holding one decoded string (special tokens skipped) for each
        sequence returned by ``model.generate``; three are requested.
    """
    generation_options = {
        "max_length": 200,
        "num_beams": 3,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
        "temperature": 1,
        "num_return_sequences": 3,
    }
    encoded_prompt = tokenizer.encode(input_sequence, return_tensors='pt')
    generated = model.generate(input_ids=encoded_prompt, **generation_options)

    decoded_questions = []
    for sequence in generated:
        decoded_questions.append(
            tokenizer.decode(sequence, skip_special_tokens=True)
        )
    return decoded_questions


def get_sorted_questions(questions, context):
    """Rank questions by the QA pipeline's confidence in answering them.

    The context is run through ``prep.preprocess`` once, then each question is
    passed to ``qa_pipe`` together with that preprocessed context.

    Args:
        questions: Iterable of question strings. Duplicates collapse into a
            single entry because the result is keyed by question text.
        context: Raw passage the questions were generated from.

    Returns:
        A dict mapping question -> score, ordered from highest to lowest
        score. A question whose scoring fails gets score 0.
    """
    preprocessed_context = prep.preprocess(context)
    scores = {}
    for question in questions:
        try:
            result = qa_pipe(question=question, context=preprocessed_context)
            scores[question] = result["score"]
        except Exception:
            # Best effort: one failing question must not discard the others,
            # but the failure is logged instead of being silently swallowed.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            logging.getLogger(__name__).warning(
                "Scoring failed for question %r; using score 0",
                question,
                exc_info=True,
            )
            scores[question] = 0
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))


def is_arabic(text):
    """Return True when every alphabetic character of ``text`` is Arabic.

    Non-alphabetic characters (digits, punctuation, whitespace, combining
    marks) are ignored, so an empty string or a string without letters
    yields True.

    Args:
        text: String to inspect.

    Returns:
        False as soon as an alphabetic character whose Unicode name does not
        start with ``ARABIC`` is found, True otherwise.
    """
    # The check runs on the raw text: reshaping and bidi reordering only
    # change presentation forms and order, and Arabic presentation forms are
    # also named "ARABIC ...", so they do not affect the outcome.
    # ``unicodedata.name`` gets a default because it raises ValueError for
    # code points without a name; such characters count as non-Arabic.
    return all(
        unicodedata.name(char, '').startswith('ARABIC')
        for char in text
        if char.isalpha()
    )


# Splits decoded model output on markup-style tags (kept in the result through
# the capture group) and on the literal text "None".
_DISTRACTOR_SPLIT_RE = re.compile(r'(<[^>]*>)|(?:None)')
# Matches a single markup-style tag so it can be removed from a fragment.
_DISTRACTOR_TAG_RE = re.compile(r'<[^>]*>')


def _extract_distractor_candidates(decoded_output):
    """Split one decoded generation into cleaned, Arabic-only fragments."""
    cleaned = (
        _DISTRACTOR_TAG_RE.sub('', fragment.strip())
        for fragment in _DISTRACTOR_SPLIT_RE.split(decoded_output)
        if fragment  # re.split yields None/'' where the capture group is unmatched
    )
    return [fragment for fragment in cleaned if fragment and is_arabic(fragment)]


def _sample_distractor_candidates(input_ids, num_sequences):
    """Sample ``num_sequences`` generations from model2 and parse them all."""
    outputs = model2.generate(
        input_ids,
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
        num_return_sequences=num_sequences,
        no_repeat_ngram_size=2,
    )
    candidates = []
    for output in outputs:
        decoded_output = tokenizer2.decode(output, skip_special_tokens=True)
        candidates.extend(_extract_distractor_candidates(decoded_output))
    return candidates


def generate_distractors(question, answer, context, num_distractors=3, k=10,
                         max_retries=5):
    """Generate wrong-answer options for a multiple-choice question.

    Candidates are sampled from ``model2``, split into fragments, filtered to
    Arabic-only text, de-duplicated and stripped of the correct answer.
    Sampling is repeated while too few unique candidates exist.

    Args:
        question: Question text.
        answer: Correct answer; never returned as a distractor.
        context: Passage the question was generated from.
        num_distractors: Number of distractors wanted.
        k: Upper bound on the candidate pool the final selection draws from.
        max_retries: Extra sampling rounds allowed after the first one. This
            bounds the loop when the model cannot produce enough unique
            candidates.

    Returns:
        A list of up to ``num_distractors`` distinct strings in random order.
        It is shorter when the retries are exhausted or when ``k`` is smaller
        than ``num_distractors``.
    """
    if num_distractors <= 0:
        return []

    input_sequence = f'{question} <sep> {answer} <sep> {context}'
    input_ids = tokenizer2.encode(input_sequence, return_tensors='pt')

    unique_distractors = []
    seen = {answer}  # seeding with the answer excludes it from the results
    rounds_left = max_retries + 1
    while len(unique_distractors) < num_distractors and rounds_left > 0:
        rounds_left -= 1
        missing = num_distractors - len(unique_distractors)
        for candidate in _sample_distractor_candidates(input_ids, missing):
            if candidate not in seen:
                seen.add(candidate)
                unique_distractors.append(candidate)

    pool_limit = max(k, 0)
    if len(unique_distractors) > pool_limit:
        unique_distractors = random.sample(unique_distractors, pool_limit)

    # Clamp the sample size: random.sample raises ValueError when asked for
    # more items than the pool holds.
    sample_size = min(num_distractors, len(unique_distractors))
    return random.sample(unique_distractors, sample_size)


# Input widgets, in the order generate_question() receives their values.
context = gr.Textbox(
    label="النص",
    placeholder="أدخل الفقرة هنا",
    lines=5,
)
answer = gr.Textbox(
    label="الإجابة",
    placeholder="أدخل الإجابة هنا",
    lines=3,
)
question_type = gr.Radio(
    label="نوع السؤال",
    choices=["سؤال مقالي", "سؤال اختيار من متعدد"],
)

# Output widget that displays the string returned by generate_question().
question = gr.Textbox(label="السؤال الناتج", type="text")


def generate_question(context, answer, question_type):
    """Build the question text shown in the UI.

    Candidate questions are generated from ``answer + "<sep>" + context``,
    ranked with ``get_sorted_questions`` and the top-ranked one is used.

    Args:
        context: Passage entered by the user.
        answer: Answer the question should lead to.
        question_type: ``"سؤال مقالي"`` returns the question alone; any other
            value produces a multiple-choice question.

    Returns:
        The question alone, or the question followed by one ``- option`` line
        per choice (distractors plus the answer, shuffled). A short message is
        returned instead when an input is empty or no question was ranked.
    """
    # Without both inputs the models would only produce noise.
    if not (context and context.strip()) or not (answer and answer.strip()):
        return "الرجاء إدخال النص والإجابة"

    article = answer + "<sep>" + context
    candidates = generate_questions(model1, tokenizer1, article)
    ranked = get_sorted_questions(candidates, context)

    # Return the fallback message directly so the multiple-choice branch does
    # not go on to generate distractors for it.
    if not ranked:
        return "لم يتم توليد سؤال مناسب"

    best_question = next(iter(ranked))  # first key = highest score
    if question_type == "سؤال مقالي":
        return best_question

    options = generate_distractors(best_question, answer, context)
    options.append(answer)
    random.shuffle(options)
    return best_question + "\n" + "\n".join("- " + opt for opt in options)


# Connect the three input widgets and the output textbox to generate_question.
iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer, question_type],
    outputs=question,
)

# Start the app when the module runs; no public share link is requested.
iface.launch(share=False, debug=True)