# Spaces: Sleeping (Hugging Face page header captured along with this source; not code)
import logging
import os

import torch
from vllm import LLM, SamplingParams

# Configure logging for the whole process: INFO level, timestamped records
# tagged with the logger name and level.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class ChallengePromptGenerator:
    """Complete seed prompts with a causal language model served by vLLM.

    The model is loaded once in ``__init__``. ``infer_prompt`` then extends
    a batch of prompts and trims each result back to its last full sentence.
    """

    def __init__(
        self,
        model_local_dir="./checkpoint-15000",
    ):
        """Load the model into a ``vllm.LLM`` engine.

        Args:
            model_local_dir: Model location handed unchanged to ``vllm.LLM``
                (defaults to the local ``./checkpoint-15000`` directory).
        """
        self.generator = LLM(
            model_local_dir,
        )

    @staticmethod
    def _drop_unfinished_sentence(text):
        """Cut *text* back to its last ``"."``, keeping that period.

        Text that already ends in ``"."`` is returned unchanged. Text
        containing no ``"."`` is one unfinished sentence and becomes ``""``.
        Empty text stays empty.
        """
        head, period, _unfinished = text.rpartition(".")
        return head + period

    def infer_prompt(
        self,
        prompts,
        max_generation_length=77,
        beam_size=1,
        sampling_temperature=0.9,
        sampling_topk=1,
        sampling_topp=1,
    ):
        """Extend each prompt with generated text that ends on a full sentence.

        Args:
            prompts: Sequence of seed strings.
            max_generation_length: Maximum number of new tokens per prompt
                (vLLM ``max_tokens``).
            beam_size: Beam width. Values > 1 enable vLLM beam search.
            sampling_temperature: Sampling temperature.
            sampling_topk: Top-k cutoff. The default of 1 keeps only the most
                likely token (greedy decoding), so the default temperature
                has no effect.
            sampling_topp: Nucleus (top-p) cutoff.

        Returns:
            One string per prompt, paired with ``prompts`` by position. Each
            string is the prompt followed by the first generated completion,
            truncated after its last ``"."``. It is ``""`` if the combined
            text has no ``"."`` at all.
        """
        # Materialize once: the prompts are iterated twice below.
        prompts = list(prompts)

        bos_token = self.generator.get_tokenizer().bos_token
        # The BOS token is prepended as plain text. For tokenizers without
        # one, send the prompt as-is instead of a literal "None " prefix.
        if bos_token:
            model_inputs = [f"{bos_token} {prompt}" for prompt in prompts]
        else:
            model_inputs = list(prompts)

        use_beam_search = beam_size > 1
        # NOTE(review): vLLM versions that accept ``use_beam_search`` also
        # require temperature=0, top_p=1 and top_k=-1 when it is True.
        # Later versions removed ``use_beam_search`` from SamplingParams.
        # Confirm against the pinned vLLM version.
        sampling_params = SamplingParams(
            max_tokens=max_generation_length,
            temperature=sampling_temperature,
            top_k=sampling_topk,
            top_p=sampling_topp,
            # Beam width. Without it beam search is rejected
            # (best_of must be > 1). None keeps vLLM's default.
            best_of=beam_size if use_beam_search else None,
            use_beam_search=use_beam_search,
        )
        outputs = self.generator.generate(model_inputs, sampling_params)

        # Rebuild from the caller's prompt (without the BOS prefix) plus
        # the first completion of each request.
        return [
            self._drop_unfinished_sentence(prompt + output.outputs[0].text)
            for prompt, output in zip(prompts, outputs)
        ]