Yjhhh committed on
Commit
cdf7569
·
verified ·
1 Parent(s): 27c0505

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +39 -82
main.py CHANGED
@@ -6,11 +6,9 @@ from transformers import (
6
  AutoTokenizer,
7
  AutoModelForSequenceClassification,
8
  AutoModelForCausalLM,
 
 
9
  )
10
- import torch
11
- import torch.nn as nn
12
- from torch.utils.data import DataLoader, Dataset
13
- from torch.optim import AdamW
14
  from fastapi import FastAPI, HTTPException, Request
15
  from fastapi.responses import HTMLResponse
16
  import multiprocessing
@@ -65,38 +63,26 @@ class ChatbotService:
65
 
66
  chatbot_service = ChatbotService()
67
 
68
- class UnifiedModel(nn.Module):
69
- def __init__(self, models):
70
- super(UnifiedModel, self).__init__()
71
- self.models = nn.ModuleList(models)
72
- hidden_size = self.models[0].config.hidden_size
73
- self.projection = nn.Linear(len(models) * 3, 768)
74
- self.classifier = nn.Linear(hidden_size, 3)
75
-
76
- def forward(self, input_ids, attention_mask):
77
- hidden_states = []
78
- for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
79
- outputs = model(input_ids=input_id, attention_mask=attn_mask)
80
- hidden_states.append(outputs.logits)
81
- concatenated_hidden_states = torch.cat(hidden_states, dim=1)
82
- projected_features = self.projection(concatenated_hidden_states)
83
- logits = self.classifier(projected_features)
84
- return logits
85
 
86
  @staticmethod
87
  def load_model_from_redis(redis_client):
88
  model_name = "unified_model"
89
- model_data_bytes = redis_client.get(f"model:{model_name}")
90
- if model_data_bytes:
91
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
92
- model.load_state_dict(torch.load(model_data_bytes))
 
 
93
  else:
94
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
95
- return UnifiedModel([model, model])
96
 
97
  class SyntheticDataset(Dataset):
98
- def __init__(self, tokenizers, data):
99
- self.tokenizers = tokenizers
100
  self.data = data
101
 
102
  def __len__(self):
@@ -106,13 +92,8 @@ class SyntheticDataset(Dataset):
106
  item = self.data[idx]
107
  text = item['text']
108
  label = item['label']
109
- tokenized = {}
110
- for name, tokenizer in self.tokenizers.items():
111
- tokens = tokenizer(text, padding="max_length", truncation=True, max_length=128)
112
- tokenized[f"input_ids_{name}"] = torch.tensor(tokens["input_ids"])
113
- tokenized[f"attention_mask_{name}"] = torch.tensor(tokens["attention_mask"])
114
- tokenized["labels"] = torch.tensor(label)
115
- return tokenized
116
 
117
  conversation_history = {}
118
 
@@ -121,22 +102,10 @@ async def process(request: Request):
121
  data = await request.json()
122
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
123
 
124
- tokenizers = {}
125
- models = {}
126
-
127
- model_name = "unified_model"
128
  tokenizer_name = "unified_tokenizer"
129
 
130
- model_data_bytes = redis_client.get(f"model:{model_name}")
131
  tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
132
 
133
- if model_data_bytes:
134
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
135
- model.load_state_dict(torch.load(model_data_bytes))
136
- else:
137
- model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
138
- models[model_name] = model
139
-
140
  if tokenizer_data_bytes:
141
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
142
  tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
@@ -144,9 +113,8 @@ async def process(request: Request):
144
  else:
145
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
146
  tokenizer.pad_token = tokenizer.eos_token
147
- tokenizers[tokenizer_name] = tokenizer
148
 
149
- unified_model = UnifiedModel(list(models.values()))
150
  unified_model.to(torch.device("cpu"))
151
 
152
  if data.get("train"):
@@ -170,11 +138,9 @@ async def process(request: Request):
170
  conversation_history[user_id] = []
171
  conversation_history[user_id].append(text)
172
  contextualized_text = " ".join(conversation_history[user_id][-3:])
173
- tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers.keys()]
174
- input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
175
- attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
176
  with torch.no_grad():
177
- logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
178
  predicted_class = torch.argmax(logits, dim=-1).item()
179
  response = chatbot_service.get_response(user_id, contextualized_text, language)
180
  redis_client.rpush("training_queue", json.dumps({
@@ -327,35 +293,26 @@ def train_unified_model():
327
  if training_queue:
328
  for item in training_queue:
329
  item_data = json.loads(item)
330
- tokenizers = {name: AutoTokenizer.from_pretrained("gpt2") for name in item_data["tokenizers"]}
331
- for tokenizer in tokenizers.values():
332
- tokenizer.pad_token = tokenizer.eos_token
 
 
333
  data = item_data["data"]
334
- dataset = SyntheticDataset(tokenizers, data)
335
- dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
336
-
337
- model = UnifiedModel([AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)])
338
- optimizer = AdamW(model.parameters(), lr=1e-5)
339
- criterion = nn.CrossEntropyLoss()
340
-
341
- for epoch in range(3):
342
- model.train()
343
- for batch in dataloader:
344
- input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in tokenizers]
345
- attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in tokenizers]
346
- labels = batch["labels"].to("cpu")
347
-
348
- optimizer.zero_grad()
349
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
350
- loss = criterion(outputs, labels)
351
- loss.backward()
352
- optimizer.step()
353
-
354
- model_data_path = "model_data.pt"
355
- torch.save(model.state_dict(), model_data_path)
356
- with open(model_data_path, "rb") as f:
357
- model_data_bytes = f.read()
358
- redis_client.set(f"model:unified_model", model_data_bytes)
359
  redis_client.delete("training_queue")
360
  time.sleep(60)
361
 
 
6
  AutoTokenizer,
7
  AutoModelForSequenceClassification,
8
  AutoModelForCausalLM,
9
+ TrainingArguments,
10
+ Trainer,
11
  )
 
 
 
 
12
  from fastapi import FastAPI, HTTPException, Request
13
  from fastapi.responses import HTMLResponse
14
  import multiprocessing
 
63
 
64
  chatbot_service = ChatbotService()
65
 
66
+ class UnifiedModel(AutoModelForSequenceClassification):
67
+ def __init__(self, config):
68
+ super().__init__(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  @staticmethod
71
  def load_model_from_redis(redis_client):
72
  model_name = "unified_model"
73
+ model_path = f"models/{model_name}"
74
+ if redis_client.exists(f"model:{model_name}"):
75
+ redis_client.delete(f"model:{model_name}")
76
+ if not os.path.exists(model_path):
77
+ model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
78
+ model.save_pretrained(model_path)
79
  else:
80
+ model = UnifiedModel.from_pretrained(model_path)
81
+ return model
82
 
83
  class SyntheticDataset(Dataset):
84
+ def __init__(self, tokenizer, data):
85
+ self.tokenizer = tokenizer
86
  self.data = data
87
 
88
  def __len__(self):
 
92
  item = self.data[idx]
93
  text = item['text']
94
  label = item['label']
95
+ tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
96
+ return {"input_ids": tokens["input_ids"].squeeze(), "attention_mask": tokens["attention_mask"].squeeze(), "labels": label}
 
 
 
 
 
97
 
98
  conversation_history = {}
99
 
 
102
  data = await request.json()
103
  redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
104
 
 
 
 
 
105
  tokenizer_name = "unified_tokenizer"
106
 
 
107
  tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
108
 
 
 
 
 
 
 
 
109
  if tokenizer_data_bytes:
110
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
111
  tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
 
113
  else:
114
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
115
  tokenizer.pad_token = tokenizer.eos_token
 
116
 
117
+ unified_model = UnifiedModel.load_model_from_redis(redis_client)
118
  unified_model.to(torch.device("cpu"))
119
 
120
  if data.get("train"):
 
138
  conversation_history[user_id] = []
139
  conversation_history[user_id].append(text)
140
  contextualized_text = " ".join(conversation_history[user_id][-3:])
141
+ tokenized_input = tokenizer(contextualized_text, return_tensors="pt")
 
 
142
  with torch.no_grad():
143
+ logits = unified_model(**tokenized_input).logits
144
  predicted_class = torch.argmax(logits, dim=-1).item()
145
  response = chatbot_service.get_response(user_id, contextualized_text, language)
146
  redis_client.rpush("training_queue", json.dumps({
 
293
  if training_queue:
294
  for item in training_queue:
295
  item_data = json.loads(item)
296
+ tokenizer_data = item_data["tokenizers"]
297
+ tokenizer_name = list(tokenizer_data.keys())[0]
298
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
299
+ tokenizer.add_tokens(json.loads(tokenizer_data[tokenizer_name]))
300
+ tokenizer.pad_token = tokenizer.eos_token
301
  data = item_data["data"]
302
+ dataset = SyntheticDataset(tokenizer, data)
303
+
304
+ model_name = "unified_model"
305
+ model_path = f"models/{model_name}"
306
+ model = UnifiedModel.from_pretrained(model_path)
307
+
308
+ training_args = TrainingArguments(
309
+ output_dir="./results",
310
+ per_device_train_batch_size=8,
311
+ num_train_epochs=3,
312
+ )
313
+ trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
314
+ trainer.train()
315
+ model.save_pretrained(model_path)
 
 
 
 
 
 
 
 
 
 
 
316
  redis_client.delete("training_queue")
317
  time.sleep(60)
318