File size: 5,437 Bytes
1ddad36
 
 
 
 
 
 
 
 
 
69c8b7c
b38e8d6
1ddad36
 
 
 
 
 
 
 
 
da136c8
7cda596
da136c8
1ddad36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b38e8d6
 
 
 
 
 
 
 
1ddad36
 
 
 
b38e8d6
 
 
 
 
 
 
69c8b7c
b38e8d6
 
 
 
 
 
 
 
 
 
 
1857d81
 
 
 
 
 
 
 
 
 
b38e8d6
 
1ddad36
 
 
 
b38e8d6
 
1ddad36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b38e8d6
1ddad36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b38e8d6
1ddad36
 
 
 
 
 
b38e8d6
1ddad36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Standard library
import re
from functools import lru_cache

# Third-party
import torch

# Huggingface Transformers
from transformers import (
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    MBartForConditionalGeneration,
    MBartTokenizer,
    T5ForConditionalGeneration,
    T5TokenizerFast,
    GenerationConfig,
)

# Local application
from api.ioprocess import processInputAndResults, check_and_insert_space,processPunctuation


"""
    Some global variables
    Add path to the models here
"""
mt5ModelPath = "duraad/nep-spell-hft-final"
mbartModelPath = "houdini001/happytt_mBART_plus_11_v2"
vartat5ModelPath = "Ayus077BCT014Bhandari/vartat5-using-100K-plus-27"


"""
    Function: generate

    This function takes a model name and input text as parameters and 
    returns the output text generated by the specified model. 
    It supports multiple models such as mT5, mBART, and VartaT5. 
    If the specified model is not available, 
    it returns a message indicating the unavailability of the model.

    Parameters:
    - model (str): Name of the model to use for text generation.
    - input (str): Input text for the model to generate output from.

    Returns:
    - str: Output text generated by the specified model or a message indicating model unavailability.
"""


def generate(model, input):

    in_sentences = inputSentenceList(input)
    out_sentences = processSentenceList(model, in_sentences)

    # TODO: add span for each before joining
    result = []
    for i, o in zip(in_sentences, out_sentences):
        result.append(processInputAndResults(i, o))
    return " ".join(result)

    # काकाले काकिलाइ माया गर्नू हुन्छ।


def inputSentenceList(input):
    # Define a regex pattern to split sentences
    # We'll split on periods, question marks, and exclamation marks
    sentence_pattern = r"(?<=[।?!\n])\s+"
    # Split the Nepali text into sentences
    sentences = re.split(sentence_pattern, input)
    for i, s in enumerate(sentences):
        sentences[i] = processPunctuation(s)
    return sentences


"""
    For working with paragraph processing
"""


def processSentenceList(model, inputSentenceList):
    out_sentence = []
    for s in inputSentenceList:
        if(len(s)>2):
            if model == "mT5":
                out_s = mt5Inference(s)
            elif model == "mBART":
                out_s = mbartInference(s)
            elif model == "VartaT5":
                out_s = vartat5Inference(s)
            else:
                return f"Model: {model} not available"
            out_sentence.append(out_s[0]["sequence"])
    return out_sentence


"""
    Below are the 3 different models for inference
"""


def mt5Inference(input):
    print("Processing mt5")

    model = MT5ForConditionalGeneration.from_pretrained(mt5ModelPath)
    tokenizer = MT5Tokenizer.from_pretrained(mt5ModelPath)
    input_ids = tokenizer("grammar: " + input, return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids=input_ids,
        max_length=512,
        num_beams=5,
        num_return_sequences=5,
        return_dict_in_generate=True,
        output_scores=True,
    )
    sequences = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    return postProcessOutput(sequences, outputs["sequences_scores"])


def mbartInference(input):
    print("Processing mbart")
    tokenizer = MBartTokenizer.from_pretrained(
        mbartModelPath, src_lang="ne_NP", tgt_lang="ne_NP"
    )
    model = MBartForConditionalGeneration.from_pretrained(mbartModelPath)
    inputs = tokenizer("grammar: " + input, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        decoder_start_token_id=tokenizer.lang_code_to_id["ne_NP"],
        max_length=512,
        num_beams=5,
        num_return_sequences=5,
        return_dict_in_generate=True,
        output_scores=True,
    )
    sequences = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    return postProcessOutput(sequences, outputs["sequences_scores"])
    # return outputs


def vartat5Inference(input):
    print("Processing varta")
    model = T5ForConditionalGeneration.from_pretrained(vartat5ModelPath)
    # return "model ready"
    tokenizer = T5TokenizerFast.from_pretrained(vartat5ModelPath)
    input_ids = tokenizer("grammar: " + input, return_tensors="pt")
    outputs = model.generate(
        **input_ids,
        max_length=512,
        num_beams=5,
        num_return_sequences=5,
        return_dict_in_generate=True,
        output_scores=True,
    )
    sequences = tokenizer.batch_decode(outputs["sequences"], skip_special_tokens=True)
    return postProcessOutput(sequences, outputs["sequences_scores"])


"""
    Post processing the model output
"""


def postProcessOutput(sequences, sequences_scores):
    probabilities = torch.exp(sequences_scores)
    unique_sequences = set()
    # Initialize the list to store filtered items
    filtered_outputs = []

    # Iterate through sequences and formatted_scores
    for sequence, score in zip(sequences, probabilities):
        # Check if the sequence is not in the set of unique sequences
        if sequence not in unique_sequences:
            # Add the sequence to the set of unique sequences
            unique_sequences.add(sequence)
            # Append the sequence and score to the filtered_outputs list
            filtered_outputs.append({"sequence": sequence, "score": score.item()})

    return filtered_outputs