acecalisto3 commited on
Commit
1e8f87a
·
verified ·
1 Parent(s): a3b29d1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional
3
+ from transformers import (
4
+ AutoConfig,
5
+ AutoModelForSequenceClassification,
6
+ AutoTokenizer,
7
+ DataCollatorWithPadding,
8
+ HfArgumentParser,
9
+ PreTrainedModel,
10
+ PretrainedConfig,
11
+ Trainer,
12
+ training_args,
13
+ )
14
+
15
+ class MockOpenAI:
16
+ """
17
+ A mock implementation of OpenAI's API using Hugging Face's pipeline for text generation.
18
+
19
+ :param api_key: Your Hugging Face API key, required for authentication.
20
+ :param base_url: The base URL for the Hugging Face API, defaults to the production URL.
21
+ :param model_name: The name of the pretrained model to use for text generation, defaults to 'gpt2'.
22
+ :param max_tokens: The maximum number of tokens to generate in the response, defaults to 50.
23
+ """
24
+ def __init__(
25
+ self,
26
+ api_key: Optional[str] = None,
27
+ base_url: Optional[str] = None,
28
+ model_name: Optional[str] = "gpt2",
29
+ max_tokens: int = 50,
30
+ ):
31
+ self.api_key = api_key or os.environ.get("HUGGING_FACE_API_KEY")
32
+ self.base_url = base_url or "https://api-inference.huggingface.co/models"
33
+ self.model_name = model_name
34
+ self.max_tokens = max_tokens
35
+ self.config = AutoConfig.from_pretrained(self.model_name)
36
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
37
+ self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, config=self.config)
38
+ self.data_collator = DataCollatorWithPadding(self.tokenizer)
39
+ self.trainer = Trainer(
40
+ model=self.model,
41
+ args=training_args(
42
+ output_dir="./",
43
+ num_train_epochs=1,
44
+ learning_rate=1e-5,
45
+ per_device_train_batch_size=16,
46
+ per_device_eval_batch_size=16,
47
+ evaluation_strategy="epoch",
48
+ ),
49
+ )
50
+
51
+ class Chat:
52
+ def __init__(self, mock_openai: MockOpenAI):
53
+ self.mock_openai = mock_openai
54
+
55
+ class Completions:
56
+ def __init__(self, mock_openai: MockOpenAI):
57
+ self.mock_openai = mock_openai
58
+
59
+ def create(
60
+ self,
61
+ messages: List[Dict[str, str]],
62
+ model: Optional[str] = None,
63
+ max_tokens: int = 50,
64
+ **kwargs,
65
+ ):
66
+ """
67
+ Generate a text completion based on the given messages.
68
+
69
+ :param messages: List of message objects, each containing 'role' and 'content'.
70
+ :param model: The name of the pretrained model to use for text generation, defaults to 'gpt2'.
71
+ :param max_tokens: The maximum number of tokens to generate in the response, defaults to 50.
72
+ :param kwargs: Additional keyword arguments to pass to the pipeline function.
73
+ :return: A dictionary containing the generated text.
74
+ """
75
+ if not self.mock_openai.config.is_decoder:
76
+ raise ValueError("This model is not a decoder.")
77
+
78
+ model_name = model or self.mock_openai.model_name
79
+ prompt = " ".join([msg["content"] for msg in messages])
80
+
81
+ inputs = self.mock_openai.tokenizer(prompt, padding="max_length", truncation=True)
82
+ outputs = self.mock_openai.trainer.predict(inputs.to_tensor(pad_to_multiple_of=self.mock_openai.config.max_length))
83
+ result = self.mock_openai.tokenizer.decode(outputs[0], skip_special_tokens=True)
84
+
85
+ if max_tokens is not None and len(result) > max_tokens:
86
+ result = result[:max_tokens]
87
+
88
+ return result
89
+
90
+ @property
91
+ def chat(self):
92
+ """
93
+ Get the Chat class instance with the pretrained model for text generation.
94
+
95
+ :return: The Chat class instance.
96
+ """
97
+ return self.Chat(self)
98
+
99
+ # Example usage
100
+ if __name__ == "__main__":
101
+ parser = HfArgumentParser(description="Mock OpenAI API using Hugging Face's pipeline for text generation.")
102
+ parser.add_argument("--model_name", default="gpt2", help="The name of the pretrained model to use for text generation.")
103
+ parser.add_argument("--max_tokens", type=int, default=50, help="The maximum number of tokens to generate in the response.")
104
+ args = parser.parse_args()
105
+ client = MockOpenAI(model_name=args.model_name, max_tokens=args.max_tokens)
106
+ chat_completion = client.chat.Completions().create(
107
+ messages=[
108
+ {
109
+ "role": "system",
110
+ "content": "You are a helpful assistant.",
111
+ },
112
+ {
113
+ "role": "user",
114
+ "content": "What is deep learning?",
115
+ }
116
+ ]
117
+ )
118
+
119
+ print(chat_completion)