Yjhhh committed
Commit ac5b7b0 · verified · 1 Parent(s): de2d386

Update main.py

Files changed (1)
  1. main.py +81 -86
main.py CHANGED
@@ -46,6 +46,17 @@ class UnifiedModel(nn.Module):
         logits = self.classifier(concatenated_hidden_states)
         return logits
 
+    @staticmethod
+    def load_model_from_redis(redis_client):
+        model_name = "unified_model"
+        model_data_bytes = redis_client.get(f"model:{model_name}")
+        if model_data_bytes:
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
+            model.load_state_dict(torch.load(model_data_bytes))
+        else:
+            model = AutoModelForSequenceClassification.from_pretrained("gpt2")
+        return UnifiedModel([model])
+
 class SyntheticDataset(Dataset):
     def __init__(self, tokenizers, data):
         self.tokenizers = tokenizers
@@ -68,6 +79,20 @@ class SyntheticDataset(Dataset):
 
 @app.post("/process")
 async def process(request: Request):
+    """
+    Processes requests for training and prediction.
+
+    Args:
+        request (Request): The incoming request object.
+
+    Returns:
+        dict: A dictionary containing either a message indicating successful
+        training data submission or the model's prediction.
+
+    Raises:
+        HTTPException: If the request does not contain 'train' or 'predict'
+        keys.
+    """
     data = await request.json()
     redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
 
@@ -102,41 +127,13 @@ async def process(request: Request):
         if not user_data:
             user_data = [{"text": "Sample text for automatic training.", "label": 0}]
 
-        train_dataset = SyntheticDataset(tokenizers, user_data)
-        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
-
-        training_args = TrainingArguments(
-            output_dir="memory",
-            evaluation_strategy="epoch",
-            learning_rate=5e-5,
-            per_device_train_batch_size=8,
-            per_device_eval_batch_size=8,
-            num_train_epochs=10,
-            weight_decay=0.01,
-            logging_steps=10,
-            optim="adamw_hf"
-        )
-
-        optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
-        unified_model.train()
-
-        for epoch in range(training_args.num_train_epochs):
-            for batch in train_loader:
-                input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in tokenizers.keys()]
-                attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in tokenizers.keys()]
-                labels = batch["labels"].to("cpu")
-                outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
-                loss = nn.CrossEntropyLoss()(outputs, labels)
-                loss.backward()
-                optimizer.step()
-                optimizer.zero_grad()
-
-            print(f"Epoch {epoch}, Loss {loss.item()}")
-
-        print("Training complete.")
+        # Add user data to Redis queue for asynchronous training
+        redis_client.rpush("training_queue", json.dumps({
+            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
+            "data": user_data
+        }))
 
-        push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name)
-        return {"message": "Model trained and updated in Redis."}
+        return {"message": "Training data received. Model will be updated asynchronously."}
 
     elif data.get("predict"):
         text = data['text']
@@ -155,6 +152,19 @@ async def process(request: Request):
 
 @app.post("/external_answer")
 async def external_answer(request: Request):
+    """
+    Provides an answer to a question using the unified model and triggers
+    asynchronous training with the new question-answer pair.
+
+    Args:
+        request (Request): The incoming request object containing the question.
+
+    Returns:
+        dict: A dictionary containing the answer to the question.
+
+    Raises:
+        HTTPException: If the request does not contain a 'question' key.
+    """
     data = await request.json()
     redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
 
@@ -162,26 +172,16 @@ async def external_answer(request: Request):
     if not question:
         raise HTTPException(status_code=400, detail="Question is required.")
 
-    model_name = "unified_model"
-    tokenizer_name = "unified_tokenizer"
-
-    model_data_bytes = redis_client.get(f"model:{model_name}")
-    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
-
-    if model_data_bytes:
-        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=2)
-        model.load_state_dict(torch.load(model_data_bytes))
-    else:
-        model = AutoModelForSequenceClassification.from_pretrained("gpt2")
+    # Load the model and tokenizer from Redis
+    unified_model = UnifiedModel.load_model_from_redis(redis_client)
+    unified_model.to(torch.device("cpu"))
 
+    tokenizer_data_bytes = redis_client.get(f"tokenizer:unified_tokenizer")
     if tokenizer_data_bytes:
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
         tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
     else:
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
-    unified_model = UnifiedModel([model])
-    unified_model.to(torch.device("cpu"))
 
     tokenized_input = tokenizer(question, return_tensors="pt")
     input_ids = tokenized_input['input_ids']
@@ -192,45 +192,22 @@ async def external_answer(request: Request):
     predicted_class = torch.argmax(logits, dim=-1).item()
     response = {"answer": f"Response to '{question}' is class {predicted_class}"}
 
-    extreme_training_data = [{"text": question, "label": predicted_class}]
-    train_dataset = SyntheticDataset({tokenizer_name: tokenizer}, extreme_training_data)
-    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
-
-    training_args = TrainingArguments(
-        output_dir="memory",
-        evaluation_strategy="epoch",
-        learning_rate=5e-5,
-        per_device_train_batch_size=8,
-        per_device_eval_batch_size=8,
-        num_train_epochs=10,
-        weight_decay=0.01,
-        logging_steps=10,
-        optim="adamw_hf"
-    )
-
-    optimizer = AdamW(unified_model.parameters(), lr=training_args.learning_rate)
-    unified_model.train()
-
-    for epoch in range(training_args.num_train_epochs):
-        for batch in train_loader:
-            input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in [tokenizer_name]]
-            attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in [tokenizer_name]]
-            labels = batch["labels"].to("cpu")
-            outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
-            loss = nn.CrossEntropyLoss()(outputs, labels)
-            loss.backward()
-            optimizer.step()
-            optimizer.zero_grad()
-
-        print(f"Epoch {epoch}, Loss {loss.item()}")
-
-    print("Extreme training complete.")
-    push_to_redis({model_name: model}, {tokenizer_name: tokenizer}, redis_client, model_name, tokenizer_name)
+    # Asynchronously train on the new data point
+    redis_client.rpush("training_queue", json.dumps({
+        "tokenizers": {"unified_tokenizer": tokenizer.get_vocab()},
+        "data": [{"text": question, "label": predicted_class}]
+    }))
 
     return response
 
 @app.get("/")
 async def get_home():
+    """
+    Serves a basic HTML page as the home route.
+
+    Returns:
+        HTMLResponse: The HTML content of the home page.
+    """
     html_code = """
     <!DOCTYPE html>
     <html>
@@ -264,6 +241,18 @@ async def get_home():
     return HTMLResponse(content=html_code)
 
 def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
+    """
+    Saves the given models and tokenizers to Redis.
+
+    Args:
+        models (dict): A dictionary of model names and their corresponding
+        PyTorch models.
+        tokenizers (dict): A dictionary of tokenizer names and their
+        corresponding tokenizers.
+        redis_client: The Redis client instance.
+        model_name (str): The base name to use for saving the models.
+        tokenizer_name (str): The base name to use for saving the tokenizers.
+    """
     for model_name, model in models.items():
         torch.save(model.state_dict(), model_name)
         redis_client.set(f"model:{model_name}", open(model_name, "rb").read())
@@ -273,11 +262,14 @@ def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
         redis_client.set(f"tokenizer:{tokenizer_name}", json.dumps(tokens))
 
 def continuous_training():
+    """
+    Continuously checks for new training data in Redis and updates the model.
+    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
 
    while True:
        try:
-           data = redis_client.get("training_queue")
+           data = redis_client.lpop("training_queue")
            if data:
                data = json.loads(data)
                unified_model = UnifiedModel.load_model_from_redis(redis_client)
@@ -302,13 +294,16 @@ def continuous_training():
                print(f"Epoch {epoch}, Loss {loss.item()}")
 
                push_to_redis(unified_model.models, data["tokenizers"], redis_client, "unified_model", "unified_tokenizer")
-               redis_client.delete("training_queue")
            time.sleep(10)
        except Exception as e:
            print(f"Error in continuous training: {e}")
            time.sleep(5)
 
 if __name__ == "__main__":
-    continuous_training()
+    # Start the continuous training process in a separate process
+    training_process = multiprocessing.Process(target=continuous_training)
+    training_process.start()
+
+    # Run the FastAPI app
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
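
Taken together, the changes move training out of the request path: /process and /external_answer now enqueue jobs with rpush on the Redis list training_queue, and continuous_training(), launched as a separate process from the __main__ block, drains that list with lpop and persists updated weights via push_to_redis. Below is a minimal, self-contained sketch of that producer/consumer hand-off, assuming a Redis server on localhost:6379; handle_job() is a hypothetical stand-in for the actual training loop, not code from this commit.

import json
import multiprocessing
import time

import redis

QUEUE = "training_queue"  # same list name the endpoints and the trainer use


def producer(job: dict) -> None:
    # Mirrors the endpoints: serialize the job and append it to the list.
    client = redis.StrictRedis(host="localhost", port=6379, decode_responses=True)
    client.rpush(QUEUE, json.dumps(job))


def handle_job(job: dict) -> None:
    # Hypothetical stand-in for the per-job training step.
    print(f"training on {len(job['data'])} example(s)")


def consumer() -> None:
    # Mirrors continuous_training(): pop and handle one job per iteration.
    client = redis.StrictRedis(host="localhost", port=6379, decode_responses=True)
    while True:
        raw = client.lpop(QUEUE)  # blpop(QUEUE, timeout=...) would avoid polling
        if raw is None:
            time.sleep(1)
            continue
        handle_job(json.loads(raw))


if __name__ == "__main__":
    worker = multiprocessing.Process(target=consumer, daemon=True)
    worker.start()
    producer({"data": [{"text": "hello", "label": 0}]})
    time.sleep(2)  # give the daemon worker time to process before exit

The switch from get/delete on a single key to rpush/lpop is what makes this safe with concurrent producers: each job is appended and consumed exactly once, whereas the old code read one value and then deleted the key wholesale, so work enqueued in between could be silently dropped.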