# MPT7BTest/InstructionTextGenerationPipeline.py
import datetime
import os
import warnings
from threading import Event, Thread
from typing import Any

import einops  # noqa: F401 -- required at runtime by MPT's trust_remote_code modules
import gradio as gr
import requests
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

import config
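
# PROMPT_FOR_GENERATION_FORMAT is referenced below but was never defined in
# this file. The definition here is a sketch of the Alpaca-style template that
# MPT-7B-Instruct expects; adjust it if your `config` module already supplies
# a prompt format.
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)
PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    response_key=RESPONSE_KEY,
)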
class InstructionTextGenerationPipeline:
    def __init__(
        self,
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_auth_token=None,
    ) -> None:
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            use_auth_token=use_auth_token,
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            use_auth_token=use_auth_token,
        )
        if tokenizer.pad_token_id is None:
            warnings.warn(
                "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
            )
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer

        # Move the model to the GPU if one is available and switch to eval mode.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()
        self.model.to(device=device, dtype=torch_dtype)

        # Default sampling settings; per-call overrides are merged in __call__.
        self.generate_kwargs = {
            "temperature": 0.5,
            "top_p": 0.92,
            "top_k": 0,
            "max_new_tokens": 512,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "repetition_penalty": 1.1,  # 1.0 means no penalty; the CTRL paper uses 1.2
        }
    def format_instruction(self, instruction):
        return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
    def __call__(self, instruction: str, **generate_kwargs: Any) -> str:
        s = self.format_instruction(instruction)
        input_ids = self.tokenizer(s, return_tensors="pt").input_ids
        input_ids = input_ids.to(self.model.device)

        # Merge per-call overrides into the default generation settings.
        gkw = {**self.generate_kwargs, **generate_kwargs}
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **gkw)
        # Slice the output_ids tensor to keep only the newly generated tokens.
        new_tokens = output_ids[0, len(input_ids[0]):]
        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return output_text
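

# Minimal usage sketch. The checkpoint name and prompt below are illustrative
# assumptions, not part of the original file; loading mosaicml/mpt-7b-instruct
# requires substantial GPU memory.
if __name__ == "__main__":
    generate = InstructionTextGenerationPipeline(
        "mosaicml/mpt-7b-instruct",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    print(generate("Explain the difference between a list and a tuple in Python."))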