from typing import Dict, List, Any from transformers import AutoModelForCausalLM, AutoTokenizer import torch import logging logging.basicConfig(level=logging.INFO) LOGGER = logging.getLogger(__name__) # Prompts for the different tasks START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n" END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n" START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n" MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n" END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n" START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n" END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n" START_PROMPT_TASK = { 1: START_PROMPT_TASK1, 2: START_PROMPT_TASK2, 3: START_PROMPT_TASK3, } END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3} SEP = "\n\n" class EndpointHandler: def __init__(self, path=""): self.model = AutoModelForCausalLM.from_pretrained( path, device_map="auto", torch_dtype=torch.bfloat16 ) LOGGER.info(f"Inference model loaded from {path}") LOGGER.info(f"Model device: {self.model.device}") # Fix the pad and bos tokens to avoid bug in the tokenizer pad_token = "" bos_token = "<|endoftext|>" self.tokenizer = AutoTokenizer.from_pretrained( "AI-Sweden-Models/gpt-sw3-6.7b", pad_token=pad_token, bos_token=bos_token ) def check_valid_inputs( self, input_a: str, input_b: str, task: int ) -> bool: """ Check if the inputs are valid """ if task not in [1, 2, 3]: return False if task == 1 or task == 3: if input_a is None: return False elif task == 2: if input_a is None or input_b is None: return False return True def tokenize_input(self, input_a: str, input_b: str, task: int) -> List[int]: """ Tokenize the input """ if task == 1 or task == 3: tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"] tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"] tokenized_sentence = self.tokenizer(input_a + SEP)["input_ids"] concatted_data = ( [self.tokenizer.bos_token_id] + tokenized_start + tokenized_sentence + tokenized_end ) elif task == 2: tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"] tokenized_middle = self.tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"] tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"] tokenized_sentence_a = self.tokenizer(input_a + SEP)["input_ids"] tokenized_sentence_b = self.tokenizer(input_b + SEP)["input_ids"] concatted_data = ( [self.tokenizer.bos_token_id] + tokenized_start + tokenized_sentence_a + tokenized_middle + tokenized_sentence_b + tokenized_end ) return concatted_data def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: """ data args: inputs (:obj: `str` | `PIL.Image` | `np.array`) kwargs Return: A :obj:`list` | `dict`: will be serialized and returned """ LOGGER.info(f"Received data: {data}") # Get inputs input_a = data.pop("input_a", None) input_b = data.pop("input_b", None) task = data.pop("task", None) parameters = data.pop("parameters", {}) # Check valid inputs if not self.check_valid_inputs(input_a, input_b, task): return [{"error": "Invalid inputs"}] if "max_new_tokens" not in parameters and "max_length" not in parameters: parameters["max_new_tokens"] = 512 # Tokenize the input tokenized_input = self.tokenize_input(input_a, input_b, task) # Move the input to the device input_ids = torch.tensor(tokenized_input).to(self.model.device) input_ids = input_ids.unsqueeze(0) # Generate the output output = self.model.generate(input_ids, **parameters) # Decode only the new part of the output decoded_output = self.tokenizer.decode( output[0][len(tokenized_input) :], skip_special_tokens=True ).strip() return [{"output": decoded_output}]