import logging

from vllm import LLM, SamplingParams

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


class ChallengePromptGenerator:
    """Generate extended prompts from seed prompts with a local LM checkpoint served through vLLM."""

    def __init__(
        self,
        model_local_dir="checkpoint-15000",
    ):
        self.generator = LLM(
            model_local_dir,
            dtype="bfloat16",
        )

    def infer_prompt(
        self,
        prompts,
        max_generation_length=77,
        beam_size=1,
        sampling_temperature=0.9,
        sampling_topk=1,
        sampling_topp=1,
    ):
        # Prepend the BOS token so the model sees the same format it was trained on.
        tokenizer = self.generator.get_tokenizer()
        added_prompts = [f"{tokenizer.bos_token} {prompt}" for prompt in prompts]

        # Note: use_beam_search is only accepted by older vLLM releases (it was later
        # removed from SamplingParams), and best_of must match the beam width when
        # beam search is enabled.
        sampling_params = SamplingParams(
            max_tokens=max_generation_length,
            temperature=sampling_temperature,
            top_k=sampling_topk,
            top_p=sampling_topp,
            best_of=beam_size,
            use_beam_search=(beam_size > 1),
        )

        outputs = self.generator.generate(added_prompts, sampling_params)

        out = []
        for prompt, output in zip(prompts, outputs):
            tmp_out = prompt + output.outputs[0].text
            # Trim any trailing partial sentence so the result ends on a full stop.
            if not tmp_out.endswith("."):
                tmp_out = ".".join(tmp_out.split(".")[:-1]) + "."
            out.append(tmp_out)
        return out
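

if __name__ == "__main__":
    # Minimal usage sketch, assuming a fine-tuned checkpoint exists at the default
    # "checkpoint-15000" path. The seed prompts below are illustrative placeholders,
    # not taken from any particular dataset.
    generator = ChallengePromptGenerator()
    completed = generator.infer_prompt(
        [
            "A watercolor painting of a lighthouse",
            "A macro photo of a snowflake",
        ],
        max_generation_length=77,
        sampling_temperature=0.9,
    )
    for prompt in completed:
        logger.info("Completed prompt: %s", prompt)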