|
from typing import Dict, List, Any |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
import logging |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
LOGGER = logging.getLogger(__name__) |
|
|
|
|
|
|
|
START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n" |
|
END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n" |
|
|
|
START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n" |
|
MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n" |
|
END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n" |
|
|
|
START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n" |
|
END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n" |
|
|
|
START_PROMPT_TASK = { |
|
1: START_PROMPT_TASK1, |
|
2: START_PROMPT_TASK2, |
|
3: START_PROMPT_TASK3, |
|
} |
|
END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3} |
|
|
|
SEP = "\n\n" |
|
|
|
|
|
class EndpointHandler: |
|
def __init__(self, path=""): |
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
path, device_map="auto", torch_dtype=torch.bfloat16 |
|
) |
|
self.tokenizer = AutoTokenizer.from_pretrained(path) |
|
LOGGER.info(f"Inference model loaded from {path}") |
|
LOGGER.info(f"Model device: {self.model.device}") |
|
|
|
def check_valid_inputs(self, input_a: str, input_b: str, task: int) -> bool: |
|
""" |
|
Check if the inputs are valid |
|
""" |
|
if task not in [1, 2, 3]: |
|
return False |
|
if task == 1 or task == 3: |
|
if input_a is None: |
|
return False |
|
elif task == 2: |
|
if input_a is None or input_b is None: |
|
return False |
|
return True |
|
|
|
def tokenize_input(self, input_a: str, input_b: str, task: int) -> List[int]: |
|
""" |
|
Tokenize the input |
|
""" |
|
if task == 1 or task == 3: |
|
tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"] |
|
tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"] |
|
tokenized_sentence = self.tokenizer(input_a + SEP)["input_ids"] |
|
concatted_data = ( |
|
[self.tokenizer.bos_token_id] |
|
+ tokenized_start |
|
+ tokenized_sentence |
|
+ tokenized_end |
|
) |
|
elif task == 2: |
|
tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"] |
|
tokenized_middle = self.tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"] |
|
tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"] |
|
tokenized_sentence_a = self.tokenizer(input_a + SEP)["input_ids"] |
|
tokenized_sentence_b = self.tokenizer(input_b + SEP)["input_ids"] |
|
concatted_data = ( |
|
[self.tokenizer.bos_token_id] |
|
+ tokenized_start |
|
+ tokenized_sentence_a |
|
+ tokenized_middle |
|
+ tokenized_sentence_b |
|
+ tokenized_end |
|
) |
|
return concatted_data |
|
|
|
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: |
|
""" |
|
data args: |
|
inputs (:obj: `str` | `PIL.Image` | `np.array`) |
|
kwargs |
|
Return: |
|
A :obj:`list` | `dict`: will be serialized and returned |
|
""" |
|
LOGGER.info(f"Received data: {data}") |
|
|
|
|
|
input_a = data.pop("input_a", None) |
|
input_b = data.pop("input_b", None) |
|
task = data.pop("task", None) |
|
parameters = data.pop("parameters", {}) |
|
|
|
|
|
if not self.check_valid_inputs(input_a, input_b, task): |
|
return [{"error": "Invalid inputs"}] |
|
if "max_new_tokens" not in parameters and "max_length" not in parameters: |
|
parameters["max_new_tokens"] = 512 |
|
|
|
|
|
tokenized_input = self.tokenize_input(input_a, input_b, task) |
|
|
|
|
|
input_ids = torch.tensor(tokenized_input).to(self.model.device) |
|
input_ids = input_ids.unsqueeze(0) |
|
|
|
|
|
output = self.model.generate(input_ids, **parameters) |
|
|
|
|
|
decoded_output = self.tokenizer.decode( |
|
output[0][len(tokenized_input) :], skip_special_tokens=True |
|
).strip() |
|
return [{"output": decoded_output}] |
|
|