from transformers import AutoTokenizer

# from huggingface_hub import notebook_login
# notebook_login()

from transformers import AutoTokenizer
# --- Model setup --------------------------------------------------------------
# Both checkpoints are loaded with from_pretrained at import time.
#   * tokenizer1 / model1: "Reham721/Subjective_QG" - used by generate_question
#     (via generate_questions) to produce candidate questions.
#   * tokenizer2 / model2: "Reham721/MCQs" - used by generate_distractors to
#     produce wrong answer options for multiple-choice questions.
tokenizer1 = AutoTokenizer.from_pretrained("Reham721/Subjective_QG")
tokenizer2 = AutoTokenizer.from_pretrained("Reham721/MCQs")

from transformers import AutoModelForSeq2SeqLM

model1 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/Subjective_QG")
model2 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/MCQs")


from arabert.preprocess import ArabertPreprocessor
from transformers import pipeline

# Normaliser applied to the context before question answering (see
# get_sorted_questions). The original author noted that an empty model name
# behaves the same here. NOTE(review): confirm against arabert's preprocessor
# docs before relying on that.
prep = ArabertPreprocessor("aubmindlab/araelectra-base-discriminator")
# Extractive QA pipeline. get_sorted_questions uses its "score" to rank how
# answerable each generated question is from the context.
qa_pipe =pipeline("question-answering",model="wissamantoun/araelectra-base-artydiqa")

def generate_questions(model, tokenizer, input_sequence):
    """Generate candidate questions for *input_sequence* using beam search.

    Args:
        model: a seq2seq model exposing ``generate`` (e.g. ``model1``).
        tokenizer: the tokenizer matching *model*.
        input_sequence: the model input text, e.g. ``answer + "<sep>" + context``.

    Returns:
        A list of decoded question strings, one per returned beam (3).
    """
    # Truncate to the tokenizer's maximum length so that a long pasted context
    # cannot exceed the model's input limit.
    input_ids = tokenizer.encode(
        input_sequence, return_tensors='pt', truncation=True
    )

    # No `temperature` here: it only affects sampling (do_sample=True), and 1
    # is already the default, so it had no effect on this beam search.
    outputs = model.generate(
        input_ids=input_ids,
        max_length=200,  # upper bound on generated length, in tokens
        num_beams=3,  # beam width; must be >= num_return_sequences
        no_repeat_ngram_size=3,  # forbid repeating any 3-gram within an output
        early_stopping=True,  # stop once num_beams finished candidates exist
        num_return_sequences=3,  # return every beam as a candidate question
    )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

def get_sorted_questions(questions, context, qa=None, preprocess=None):
    """Rank candidate questions by how confidently a QA model answers them.

    Each question is run through the extractive QA callable against the
    preprocessed *context*. The result's ``"score"`` is the ranking key.

    Args:
        questions: iterable of question strings. Duplicates collapse to a
            single entry, and the last score wins.
        context: the passage the questions were generated from.
        qa: callable ``qa(question=..., context=...) -> {"score": ...}``.
            Defaults to the module-level ``qa_pipe``.
        preprocess: callable applied to *context* before QA. Defaults to
            ``prep.preprocess``.

    Returns:
        A dict mapping question -> score, ordered from highest to lowest
        score. Ties keep their input order.
    """
    import logging

    logger = logging.getLogger(__name__)
    if qa is None:
        qa = qa_pipe
    if preprocess is None:
        preprocess = prep.preprocess

    context = preprocess(context)
    scores = {}
    for question in questions:
        result = qa(question=question, context=context)
        # Diagnostic output goes through logging rather than print, so it does
        # not spam stdout on every request.
        logger.debug("question=%r qa_result=%r", question, result)
        scores[question] = result["score"]

    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))



import unicodedata
import arabic_reshaper
from bidi.algorithm import get_display

def is_arabic(text):
    """Return True if every alphabetic character in *text* is Arabic-script.

    Non-alphabetic characters (digits, whitespace, punctuation, combining
    marks) are ignored, so an empty string or a string with no letters
    returns True. A letter whose Unicode name is unavailable counts as
    non-Arabic instead of raising ``ValueError``.

    Glyph reshaping and bidi reordering are not applied: they only swap Arabic
    letters for presentation forms (whose Unicode names also start with
    "ARABIC") and reorder characters for display. Neither affects this
    per-character classification.
    """
    return all(
        unicodedata.name(char, "").startswith("ARABIC")
        for char in text
        if char.isalpha()
    )

import random
import re
def _split_distractors(decoded_output):
    """Split one decoded model output into cleaned, Arabic-only distractor strings."""
    # The model emits several options separated by <...> markers. The literal
    # text "None" is also treated as a separator. re.split with a capturing
    # group yields None entries for the non-matching alternative, hence `if piece`.
    pieces = re.split(r'(<[^>]*>)|(?:None)', decoded_output)
    cleaned = (re.sub(r'<[^>]*>', '', piece.strip()) for piece in pieces if piece)
    return [piece for piece in cleaned if piece and is_arabic(piece)]


def _sample_distractors(input_ids, count):
    """Sample *count* sequences from model2 and return every distractor parsed from them."""
    outputs = model2.generate(
        input_ids,
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
        num_return_sequences=count,
        no_repeat_ngram_size=2,
    )
    found = []
    for output in outputs:
        decoded_output = tokenizer2.decode(output, skip_special_tokens=True)
        found.extend(_split_distractors(decoded_output))
    return found


def generate_distractors(question, answer, context, num_distractors=3, k=10, max_attempts=10):
    """Generate wrong answer options for a multiple-choice question.

    Samples from ``model2`` conditioned on ``question <sep> answer <sep> context``.
    It keeps unique Arabic-only candidates that differ from *answer*. While
    too few have been found, it resamples for at most *max_attempts* extra
    rounds, so the function cannot hang.

    Args:
        question: the question text.
        answer: the correct answer; it is never returned as a distractor.
        context: the source passage.
        num_distractors: how many distractors to return.
        k: the final choice is drawn from a random subset of at most *k*
            unique candidates.
        max_attempts: maximum number of extra sampling rounds.

    Returns:
        A randomly ordered list of distinct distractor strings. It has
        *num_distractors* items unless the model could not produce that many
        within *max_attempts* rounds (or *k* is smaller), in which case it is
        shorter.
    """
    if num_distractors <= 0:
        return []

    input_sequence = f'{question} <sep> {answer} <sep> {context}'
    # Truncate so a long context cannot exceed the model's input limit.
    input_ids = tokenizer2.encode(input_sequence, return_tensors='pt', truncation=True)
    # Candidates are whitespace-stripped, so compare against a stripped answer.
    correct = answer.strip()

    unique_distractors = []
    candidates = _sample_distractors(input_ids, num_distractors)
    attempts = 0
    while True:
        for candidate in candidates:
            if candidate != correct and candidate not in unique_distractors:
                unique_distractors.append(candidate)
        if len(unique_distractors) >= num_distractors or attempts >= max_attempts:
            break
        attempts += 1
        candidates = _sample_distractors(
            input_ids, num_distractors - len(unique_distractors)
        )

    # After shuffling, the first k items are a uniformly random subset.
    random.shuffle(unique_distractors)
    pool = unique_distractors[:k]
    return random.sample(pool, min(num_distractors, len(pool)))



import gradio as gr

# Input/output components for the Interface below. The legacy gr.inputs and
# gr.outputs namespaces are deprecated and were removed in Gradio 4; the
# top-level component classes are the supported equivalents.
context = gr.Textbox(lines=5, placeholder="Enter paragraph/context here...")
answer = gr.Textbox(lines=3, placeholder="Enter answer/keyword here...")
# Default to "Subjective". Otherwise an untouched radio sends None, which the
# handler treats as the MCQ branch.
question_type = gr.Radio(
    choices=["Subjective", "MCQ"], value="Subjective", label="Question type"
)
question = gr.Textbox(label="Question")

def generate_question(context, answer, question_type):
    """Gradio handler: produce the best question for *answer* within *context*.

    For "Subjective", returns the top-ranked generated question. For any other
    type (MCQ), returns that question followed by the generated distractors
    and then the correct answer, one option per line, each prefixed with "-".
    """
    article = answer + "<sep>" + context
    candidates = generate_questions(model1, tokenizer1, article)
    ranked = get_sorted_questions(candidates, context)
    if not ranked:
        return "No question could be generated."
    # The ranking is ordered best-first, so its first key is the top question.
    top_question = next(iter(ranked))

    if question_type == "Subjective":
        return top_question

    # Condition the distractor model on the generated question text. The
    # module-level name `question` is the Gradio output component, not text.
    options = list(generate_distractors(top_question, answer, context))
    # Append the correct answer as the final option. Assigning options[3] is
    # out of range because generate_distractors returns 3 items by default.
    options.append(answer)
    return top_question + "\n" + "".join("-" + option + "\n" for option in options)

# `list_outputs` and `rtl` were removed: they are not gr.Interface parameters,
# so Gradio ignored them.
iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer, question_type],
    outputs=question,
)

# share=False serves the app locally only. Pass share=True to get a temporary
# public link. debug=True keeps the launch call blocking so errors are printed.
iface.launch(debug=True, share=False)