# T2IPromptGenerator / model.py
import logging
import os

import torch
from vllm import LLM, SamplingParams

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class ChallengePromptGenerator:
    def __init__(
        self,
        model_local_dir="./checkpoint-15000",
    ):
        # Load the fine-tuned checkpoint with vLLM for batched inference.
        self.generator = LLM(
            model_local_dir,
        )

    def infer_prompt(
        self,
        prompts,
        max_generation_length=77,
        beam_size=1,
        sampling_temperature=0.9,
        sampling_topk=1,
        sampling_topp=1,
    ):
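        """Complete a batch of partial text-to-image prompts.

        Returns one completed prompt string per input, with any trailing
        unfinished sentence trimmed at the last period.
        """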
        # Prepend the tokenizer's BOS token to each prompt before generation.
        bos_token = self.generator.get_tokenizer().bos_token
        added_prompts = [f"{bos_token} {prompt}" for prompt in prompts]

        # vLLM's beam search expects best_of > 1 together with greedy-style
        # sampling settings, so only enable it when beam_size > 1.
        use_beam = beam_size > 1
        sampling_params = SamplingParams(
            max_tokens=max_generation_length,
            temperature=0 if use_beam else sampling_temperature,
            top_k=-1 if use_beam else sampling_topk,
            top_p=1 if use_beam else sampling_topp,
            best_of=beam_size if use_beam else 1,
            use_beam_search=use_beam,
        )

        outputs = self.generator.generate(added_prompts, sampling_params)

        out = []
        for prompt, output in zip(prompts, outputs):
            tmp_out = prompt + output.outputs[0].text
            # Drop the last unfinished sentence so the output ends at a period.
            if tmp_out and not tmp_out.endswith("."):
                tmp_out = ".".join(tmp_out.split(".")[:-1])
            out.append(tmp_out)
        return out
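

# Minimal usage sketch (assumptions: the checkpoint directory exists locally and
# vLLM can load it on this machine; adjust model_local_dir and the example
# prompts as needed).
if __name__ == "__main__":
    generator = ChallengePromptGenerator(model_local_dir="./checkpoint-15000")
    completions = generator.infer_prompt(
        ["a photo of", "an oil painting of"],
        max_generation_length=77,
        sampling_temperature=0.9,
    )
    for completion in completions:
        logger.info("Generated prompt: %s", completion)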