import uuid
import json

import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    GPT2LMHeadModel,
    GenerationConfig,
)


class CyberClassic(torch.nn.Module):
    """Text generator paired with a discriminator that ranks its outputs.

    Loads a causal language model and a sequence-classification model (each
    with its own tokenizer) from the ``Roaoch/CyberClassic-*`` checkpoints.
    ``generate`` produces several candidate texts from random starting
    phrases and returns the one the discriminator scores highest;
    ``answer`` continues a user-supplied prompt.
    """

    # Number of starting phrases sampled (and candidates scored) per
    # ``generate`` call.
    NUM_CANDIDATES: int = 4

    def __init__(
        self,
        max_length: int,
        startings_path: str
    ) -> None:
        """Load the starting phrases, both models, and the generation config.

        Args:
            max_length: Upper bound on newly generated tokens per sequence
                (passed to ``GenerationConfig`` as ``max_new_tokens``).
            startings_path: Path to a CSV file read with ``pandas.read_csv``;
                ``generate`` reads its ``'text'`` column.
        """
        super().__init__()
        self.max_length = max_length
        self.startings = pd.read_csv(startings_path)

        self.tokenizer = AutoTokenizer.from_pretrained('Roaoch/CyberClassic-Generator')
        self.generator: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained('Roaoch/CyberClassic-Generator')

        self.discriminator_tokenizer = AutoTokenizer.from_pretrained('Roaoch/CyberClassic-Discriminator')
        self.discriminator = AutoModelForSequenceClassification.from_pretrained('Roaoch/CyberClassic-Discriminator')

        # Beam search (6 beams) combined with sampling; generation stops at
        # the generator tokenizer's EOS token and pads with its pad token.
        self.generation_config = GenerationConfig(
            max_new_tokens=max_length,
            num_beams=6,
            early_stopping=True,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
        )

    def generate(self) -> str:
        """Generate candidate texts and return the best-scored one.

        Samples ``NUM_CANDIDATES`` starting phrases (with replacement) from
        the ``'text'`` column of the startings table, lets the generator
        continue each of them, scores the decoded texts with the
        discriminator, and returns the text owning the largest logit.

        Returns:
            The decoded candidate text with the highest discriminator logit.
        """
        picks = np.random.randint(0, len(self.startings), self.NUM_CANDIDATES)
        starts = self.startings['text'].values[picks].tolist()

        tokens = self.tokenizer(starts, return_tensors='pt', padding=True, truncation=True)
        generated = self.generator.generate(**tokens, generation_config=self.generation_config)
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)

        decoded_tokens = self.discriminator_tokenizer(
            decoded, return_tensors='pt', padding=True, truncation=True
        )
        # Scoring is inference only; skip building an autograd graph.
        with torch.no_grad():
            score = self.discriminator(**decoded_tokens)

        logits = score.logits
        # argmax without ``dim`` indexes the *flattened* logits. Map the flat
        # position back to its row so the result is always a valid candidate
        # index, even when the classifier emits more than one logit per text.
        # NOTE(review): with several labels this picks the candidate holding
        # the single largest logit of any class; if one specific class is the
        # "good" one, select that column explicitly - confirm against the
        # discriminator's label layout.
        logits_per_candidate = max(logits.numel() // len(decoded), 1)
        index = int(torch.argmax(logits)) // logits_per_candidate
        return decoded[index]

    def answer(self, promt: str) -> str:
        """Continue ``promt`` with the generator and return only the new text.

        Args:
            promt: User prompt; ``'. '`` is appended before tokenization.

        Returns:
            The generated continuation, stripped of surrounding whitespace,
            without the prompt itself.
        """
        promt = promt + '. '
        promt_tokens = self.tokenizer(promt, return_tensors='pt')
        output = self.generator.generate(
            **promt_tokens,
            generation_config=self.generation_config,
        )
        # The causal LM returns prompt tokens followed by the new tokens.
        # Drop the prompt at token level rather than slicing the decoded
        # string by character count, which misaligns whenever decoding does
        # not reproduce the prompt text exactly.
        promt_length = promt_tokens['input_ids'].shape[1]
        continuation = output[0][promt_length:]
        return self.tokenizer.decode(continuation, skip_special_tokens=True).strip()