"""OpenAI-style chat-completions mock backed by a local Hugging Face causal language model."""

import os
from dataclasses import dataclass, field
from functools import cached_property
from typing import Dict, List, Optional

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)


@dataclass
class ModelArguments:
    """Command-line arguments selecting the model and the default generation length."""

    model_name: str = field(
        default="gpt2",
        metadata={"help": "The name of the pretrained model to use for text generation."},
    )
    max_tokens: int = field(
        default=50,
        metadata={"help": "The maximum number of tokens to generate in the response."},
    )


class MockOpenAI:
    """
    A mock of OpenAI's chat-completions client that generates text locally with a
    Hugging Face causal language model.

    Usage mirrors the OpenAI SDK: ``client.chat.completions.create(messages=[...])``.

    :param api_key: Hugging Face API key; falls back to the ``HUGGING_FACE_API_KEY``
        environment variable. It is stored on the instance but not used when loading
        the model.
    :param base_url: Hugging Face Inference API base URL, kept for API parity; it is not
        used because generation runs locally.
    :param model_name: Name or path of the pretrained causal LM, defaults to 'gpt2'.
    :param max_tokens: Default maximum number of new tokens to generate, defaults to 50.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model_name: Optional[str] = "gpt2",
        max_tokens: int = 50,
    ):
        self.api_key = api_key or os.environ.get("HUGGING_FACE_API_KEY")
        self.base_url = base_url or "https://api-inference.huggingface.co/models"
        # `model_name` is Optional in the signature; treat None as the documented default.
        self.model_name = model_name or "gpt2"
        self.max_tokens = max_tokens
        self.config = AutoConfig.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # GPT-2-style tokenizers ship without a pad token, but padding (the collator,
        # `generate`'s pad_token_id) needs one; EOS is the conventional stand-in.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Generation needs a language-modelling head: a sequence-classification model
        # only yields class logits and cannot produce tokens.
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, config=self.config)
        self.data_collator = DataCollatorWithPadding(self.tokenizer)

    @cached_property
    def trainer(self) -> Trainer:
        """
        Build (on first access) a Trainer for optional fine-tuning of ``self.model``.

        Generation does not use the Trainer, so construction is deferred rather than
        done in ``__init__``. No evaluation strategy is configured because no eval
        dataset is supplied; per-epoch evaluation without one is rejected by recent
        transformers releases, where the old ``evaluation_strategy`` keyword was also
        renamed to ``eval_strategy``.

        :return: The cached Trainer instance.
        """
        return Trainer(
            model=self.model,
            args=TrainingArguments(
                output_dir="./",
                num_train_epochs=1,
                learning_rate=1e-5,
                per_device_train_batch_size=16,
                per_device_eval_batch_size=16,
            ),
        )

    class Chat:
        """Namespace mirroring ``client.chat`` in the OpenAI SDK."""

        def __init__(self, mock_openai: "MockOpenAI"):
            self.mock_openai = mock_openai

        @property
        def completions(self) -> "MockOpenAI.Chat.Completions":
            """
            Return a Completions endpoint bound to the owning client.

            This mirrors ``client.chat.completions`` in the OpenAI SDK. The nested
            ``Completions`` class needs the client instance, so it cannot be
            instantiated bare.

            :return: A new Completions instance.
            """
            return self.Completions(self.mock_openai)

        class Completions:
            """Chat-completions endpoint that generates text with the client's local model."""

            def __init__(self, mock_openai: "MockOpenAI"):
                self.mock_openai = mock_openai

            def create(
                self,
                messages: List[Dict[str, str]],
                model: Optional[str] = None,
                max_tokens: Optional[int] = None,
                **kwargs,
            ) -> str:
                """
                Generate a text completion based on the given messages.

                :param messages: Message objects, each containing 'role' and 'content'.
                    Only the contents are used; they are joined with single spaces to
                    form the prompt.
                :param model: Accepted for OpenAI API compatibility and ignored; the
                    model loaded by the owning ``MockOpenAI`` is always used.
                :param max_tokens: Maximum number of new tokens to generate. ``None``
                    uses the client's ``max_tokens`` (50 by default).
                :param kwargs: Extra keyword arguments forwarded to ``model.generate``
                    (e.g. ``do_sample``, ``temperature``); they override the defaults
                    set here.
                :return: The generated continuation, excluding the prompt text.
                :raises ValueError: If ``messages`` is empty, the prompt tokenizes to
                    nothing, or the token limit is not positive.
                """
                if not messages:
                    raise ValueError("messages must contain at least one message.")
                client = self.mock_openai
                limit = client.max_tokens if max_tokens is None else max_tokens
                if limit < 1:
                    raise ValueError(f"max_tokens must be a positive integer, got {limit}.")

                prompt = " ".join(msg["content"] for msg in messages)
                # No padding: a single prompt needs none, and padding to max_length
                # fails on tokenizers without a pad token.
                inputs = client.tokenizer(prompt, return_tensors="pt").to(client.model.device)
                prompt_length = inputs["input_ids"].shape[-1]
                if prompt_length == 0:
                    raise ValueError("messages produced an empty prompt.")

                generate_kwargs = {
                    "max_new_tokens": limit,
                    "pad_token_id": client.tokenizer.pad_token_id,
                    **kwargs,
                }
                output_ids = client.model.generate(**inputs, **generate_kwargs)
                # For decoder-only models `generate` returns prompt + continuation;
                # slice off the prompt so only newly generated text is returned.
                return client.tokenizer.decode(
                    output_ids[0, prompt_length:], skip_special_tokens=True
                )

    @property
    def chat(self) -> "MockOpenAI.Chat":
        """
        Get a Chat namespace bound to this client.

        :return: A new Chat instance.
        """
        return self.Chat(self)


# Example usage
if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments,))
    model_args = parser.parse_args_into_dataclasses()[0]
    client = MockOpenAI(model_name=model_args.model_name, max_tokens=model_args.max_tokens)
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant.",
            },
            {
                "role": "user",
                "content": "What is deep learning?",
            },
        ]
    )
    print(chat_completion)