import logging
import os

import torch
from vllm import LLM, SamplingParams

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class ChallengePromptGenerator:
    """Generates challenge prompts by extending input prompts with a local checkpoint served through vLLM."""

    def __init__(
        self,
        model_local_dir="checkpoint-15000",
    ):
        # Load the fine-tuned checkpoint with vLLM in bfloat16.
        self.generator = LLM(
            model_local_dir,
            dtype="bfloat16",
        )

    def infer_prompt(
        self,
        prompts,
        max_generation_length=77,
        beam_size=1,
        sampling_temperature=0.9,
        sampling_topk=1,
        sampling_topp=1,
    ):
        # Prepend the tokenizer's BOS token to each prompt before generation.
        tokenizer = self.generator.get_tokenizer()
        added_prompts = [f"{tokenizer.bos_token} {prompt}" for prompt in prompts]

        # vLLM's beam-search path requires greedy settings (best_of > 1,
        # temperature 0, top_k -1, top_p 1), so the sampling knobs are only
        # passed through when beam_size == 1.
        use_beam_search = beam_size > 1
        sampling_params = SamplingParams(
            max_tokens=max_generation_length,
            temperature=0.0 if use_beam_search else sampling_temperature,
            top_k=-1 if use_beam_search else sampling_topk,
            top_p=1.0 if use_beam_search else sampling_topp,
            best_of=beam_size,
            use_beam_search=use_beam_search,
        )

        outputs = self.generator.generate(added_prompts, sampling_params)
        out = []
        for i in range(len(outputs)):
            # Return the original prompt followed by its generated continuation.
            tmp_out = prompts[i] + outputs[i].outputs[0].text
            # Trim any trailing partial sentence so the result ends on a period.
            if tmp_out and not tmp_out.endswith(".") and "." in tmp_out:
                tmp_out = ".".join(tmp_out.split(".")[:-1]) + "."
            out.append(tmp_out)
        return out
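

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the example prompts and logging
    # below are assumptions, not part of the original pipeline). It loads the
    # default "checkpoint-15000" directory from __init__ and expands two short
    # prompts into longer ones.
    generator = ChallengePromptGenerator()
    expanded = generator.infer_prompt(
        ["a photo of a cat", "a futuristic city skyline"],
        max_generation_length=77,
        sampling_temperature=0.9,
    )
    for prompt in expanded:
        logger.info("Generated prompt: %s", prompt)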