Update main.py
Browse files
main.py
CHANGED
@@ -6,11 +6,9 @@ from transformers import (
|
|
6 |
AutoTokenizer,
|
7 |
AutoModelForSequenceClassification,
|
8 |
AutoModelForCausalLM,
|
|
|
|
|
9 |
)
|
10 |
-
import torch
|
11 |
-
import torch.nn as nn
|
12 |
-
from torch.utils.data import DataLoader, Dataset
|
13 |
-
from torch.optim import AdamW
|
14 |
from fastapi import FastAPI, HTTPException, Request
|
15 |
from fastapi.responses import HTMLResponse
|
16 |
import multiprocessing
|
@@ -65,38 +63,26 @@ class ChatbotService:
|
|
65 |
|
66 |
chatbot_service = ChatbotService()
|
67 |
|
68 |
-
class UnifiedModel(
|
69 |
-
def __init__(self,
|
70 |
-
super(
|
71 |
-
self.models = nn.ModuleList(models)
|
72 |
-
hidden_size = self.models[0].config.hidden_size
|
73 |
-
self.projection = nn.Linear(len(models) * 3, 768)
|
74 |
-
self.classifier = nn.Linear(hidden_size, 3)
|
75 |
-
|
76 |
-
def forward(self, input_ids, attention_mask):
|
77 |
-
hidden_states = []
|
78 |
-
for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
|
79 |
-
outputs = model(input_ids=input_id, attention_mask=attn_mask)
|
80 |
-
hidden_states.append(outputs.logits)
|
81 |
-
concatenated_hidden_states = torch.cat(hidden_states, dim=1)
|
82 |
-
projected_features = self.projection(concatenated_hidden_states)
|
83 |
-
logits = self.classifier(projected_features)
|
84 |
-
return logits
|
85 |
|
86 |
@staticmethod
|
87 |
def load_model_from_redis(redis_client):
|
88 |
model_name = "unified_model"
|
89 |
-
|
90 |
-
if
|
91 |
-
|
92 |
-
|
|
|
|
|
93 |
else:
|
94 |
-
model =
|
95 |
-
return
|
96 |
|
97 |
class SyntheticDataset(Dataset):
|
98 |
-
def __init__(self,
|
99 |
-
self.
|
100 |
self.data = data
|
101 |
|
102 |
def __len__(self):
|
@@ -106,13 +92,8 @@ class SyntheticDataset(Dataset):
|
|
106 |
item = self.data[idx]
|
107 |
text = item['text']
|
108 |
label = item['label']
|
109 |
-
|
110 |
-
|
111 |
-
tokens = tokenizer(text, padding="max_length", truncation=True, max_length=128)
|
112 |
-
tokenized[f"input_ids_{name}"] = torch.tensor(tokens["input_ids"])
|
113 |
-
tokenized[f"attention_mask_{name}"] = torch.tensor(tokens["attention_mask"])
|
114 |
-
tokenized["labels"] = torch.tensor(label)
|
115 |
-
return tokenized
|
116 |
|
117 |
conversation_history = {}
|
118 |
|
@@ -121,22 +102,10 @@ async def process(request: Request):
|
|
121 |
data = await request.json()
|
122 |
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
|
123 |
|
124 |
-
tokenizers = {}
|
125 |
-
models = {}
|
126 |
-
|
127 |
-
model_name = "unified_model"
|
128 |
tokenizer_name = "unified_tokenizer"
|
129 |
|
130 |
-
model_data_bytes = redis_client.get(f"model:{model_name}")
|
131 |
tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
|
132 |
|
133 |
-
if model_data_bytes:
|
134 |
-
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
|
135 |
-
model.load_state_dict(torch.load(model_data_bytes))
|
136 |
-
else:
|
137 |
-
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
|
138 |
-
models[model_name] = model
|
139 |
-
|
140 |
if tokenizer_data_bytes:
|
141 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
142 |
tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
|
@@ -144,9 +113,8 @@ async def process(request: Request):
|
|
144 |
else:
|
145 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
146 |
tokenizer.pad_token = tokenizer.eos_token
|
147 |
-
tokenizers[tokenizer_name] = tokenizer
|
148 |
|
149 |
-
unified_model = UnifiedModel
|
150 |
unified_model.to(torch.device("cpu"))
|
151 |
|
152 |
if data.get("train"):
|
@@ -170,11 +138,9 @@ async def process(request: Request):
|
|
170 |
conversation_history[user_id] = []
|
171 |
conversation_history[user_id].append(text)
|
172 |
contextualized_text = " ".join(conversation_history[user_id][-3:])
|
173 |
-
|
174 |
-
input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
|
175 |
-
attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
|
176 |
with torch.no_grad():
|
177 |
-
logits = unified_model(
|
178 |
predicted_class = torch.argmax(logits, dim=-1).item()
|
179 |
response = chatbot_service.get_response(user_id, contextualized_text, language)
|
180 |
redis_client.rpush("training_queue", json.dumps({
|
@@ -327,35 +293,26 @@ def train_unified_model():
|
|
327 |
if training_queue:
|
328 |
for item in training_queue:
|
329 |
item_data = json.loads(item)
|
330 |
-
|
331 |
-
|
332 |
-
|
|
|
|
|
333 |
data = item_data["data"]
|
334 |
-
dataset = SyntheticDataset(
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
optimizer.zero_grad()
|
349 |
-
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
350 |
-
loss = criterion(outputs, labels)
|
351 |
-
loss.backward()
|
352 |
-
optimizer.step()
|
353 |
-
|
354 |
-
model_data_path = "model_data.pt"
|
355 |
-
torch.save(model.state_dict(), model_data_path)
|
356 |
-
with open(model_data_path, "rb") as f:
|
357 |
-
model_data_bytes = f.read()
|
358 |
-
redis_client.set(f"model:unified_model", model_data_bytes)
|
359 |
redis_client.delete("training_queue")
|
360 |
time.sleep(60)
|
361 |
|
|
|
6 |
AutoTokenizer,
|
7 |
AutoModelForSequenceClassification,
|
8 |
AutoModelForCausalLM,
|
9 |
+
TrainingArguments,
|
10 |
+
Trainer,
|
11 |
)
|
|
|
|
|
|
|
|
|
12 |
from fastapi import FastAPI, HTTPException, Request
|
13 |
from fastapi.responses import HTMLResponse
|
14 |
import multiprocessing
|
|
|
63 |
|
64 |
chatbot_service = ChatbotService()
|
65 |
|
66 |
+
class UnifiedModel(AutoModelForSequenceClassification):
    """Sequence-classification model handle used by the chat service.

    The class subclasses ``AutoModelForSequenceClassification`` and is used
    through ``from_pretrained`` (see ``load_model_from_redis``).

    NOTE(review): transformers auto classes are factories; ``from_pretrained``
    is expected to return an instance of the concrete architecture class
    (for "gpt2", a GPT-2 sequence-classification model) rather than an
    instance of ``UnifiedModel`` -- confirm before relying on ``isinstance``.
    """

    def __init__(self, config):
        # NOTE(review): auto classes are documented as not directly
        # instantiable, so this constructor is presumably never reached via
        # ``from_pretrained`` -- confirm whether it can be dropped.
        super().__init__(config)

    @staticmethod
    def load_model_from_redis(redis_client):
        """Return the unified classifier, creating it on disk if necessary.

        Despite the name, the weights are read from the local directory
        ``models/unified_model``; Redis is only touched to remove the
        ``model:unified_model`` key when it exists.

        Args:
            redis_client: Client exposing ``exists(key)`` and ``delete(key)``.

        Returns:
            The model loaded from ``models/unified_model`` when that path
            exists, otherwise a model initialised from "gpt2" with 3 labels,
            which is saved to that path before being returned.
        """
        model_name = "unified_model"
        model_path = f"models/{model_name}"
        redis_key = f"model:{model_name}"

        # Drop the copy of the model stored in Redis; the directory on disk
        # is what this loader reads from.
        if redis_client.exists(redis_key):
            redis_client.delete(redis_key)

        newly_created = not os.path.exists(model_path)
        if newly_created:
            model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
        else:
            model = UnifiedModel.from_pretrained(model_path)

        # GPT-2 ships without a padding token. The tokenizer side of this
        # file uses EOS as the pad token, so mirror that in the model config;
        # without it the classification head cannot handle padded batches
        # larger than one example.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = model.config.eos_token_id

        if newly_created:
            # Saved after the config fix so that later loads from
            # ``model_path`` already carry the pad token id.
            model.save_pretrained(model_path)
        return model
|
82 |
|
83 |
class SyntheticDataset(Dataset):
|
84 |
+
def __init__(self, tokenizer, data):
|
85 |
+
self.tokenizer = tokenizer
|
86 |
self.data = data
|
87 |
|
88 |
def __len__(self):
|
|
|
92 |
item = self.data[idx]
|
93 |
text = item['text']
|
94 |
label = item['label']
|
95 |
+
tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
|
96 |
+
return {"input_ids": tokens["input_ids"].squeeze(), "attention_mask": tokens["attention_mask"].squeeze(), "labels": label}
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
conversation_history = {}
|
99 |
|
|
|
102 |
data = await request.json()
|
103 |
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
|
104 |
|
|
|
|
|
|
|
|
|
105 |
tokenizer_name = "unified_tokenizer"
|
106 |
|
|
|
107 |
tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
if tokenizer_data_bytes:
|
110 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
111 |
tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
|
|
|
113 |
else:
|
114 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
115 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
116 |
|
117 |
+
unified_model = UnifiedModel.load_model_from_redis(redis_client)
|
118 |
unified_model.to(torch.device("cpu"))
|
119 |
|
120 |
if data.get("train"):
|
|
|
138 |
conversation_history[user_id] = []
|
139 |
conversation_history[user_id].append(text)
|
140 |
contextualized_text = " ".join(conversation_history[user_id][-3:])
|
141 |
+
tokenized_input = tokenizer(contextualized_text, return_tensors="pt")
|
|
|
|
|
142 |
with torch.no_grad():
|
143 |
+
logits = unified_model(**tokenized_input).logits
|
144 |
predicted_class = torch.argmax(logits, dim=-1).item()
|
145 |
response = chatbot_service.get_response(user_id, contextualized_text, language)
|
146 |
redis_client.rpush("training_queue", json.dumps({
|
|
|
293 |
if training_queue:
|
294 |
for item in training_queue:
|
295 |
item_data = json.loads(item)
|
296 |
+
tokenizer_data = item_data["tokenizers"]
|
297 |
+
tokenizer_name = list(tokenizer_data.keys())[0]
|
298 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
299 |
+
tokenizer.add_tokens(json.loads(tokenizer_data[tokenizer_name]))
|
300 |
+
tokenizer.pad_token = tokenizer.eos_token
|
301 |
data = item_data["data"]
|
302 |
+
dataset = SyntheticDataset(tokenizer, data)
|
303 |
+
|
304 |
+
model_name = "unified_model"
|
305 |
+
model_path = f"models/{model_name}"
|
306 |
+
model = UnifiedModel.from_pretrained(model_path)
|
307 |
+
|
308 |
+
training_args = TrainingArguments(
|
309 |
+
output_dir="./results",
|
310 |
+
per_device_train_batch_size=8,
|
311 |
+
num_train_epochs=3,
|
312 |
+
)
|
313 |
+
trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
|
314 |
+
trainer.train()
|
315 |
+
model.save_pretrained(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
redis_client.delete("training_queue")
|
317 |
time.sleep(60)
|
318 |
|