Update main.py
Browse files
main.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
from dotenv import load_dotenv
|
2 |
import os
|
3 |
import json
|
4 |
-
import requests
|
5 |
import redis
|
6 |
from transformers import (
|
7 |
AutoTokenizer,
|
@@ -186,12 +185,6 @@ async def process(request: Request):
|
|
186 |
else:
|
187 |
raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
|
188 |
|
189 |
-
def get_chatbot_response(user_id, question, predicted_class, language):
|
190 |
-
if user_id not in conversation_history:
|
191 |
-
conversation_history[user_id] = []
|
192 |
-
conversation_history[user_id].append(question)
|
193 |
-
return chatbot_service.get_response(user_id, question, language)
|
194 |
-
|
195 |
@app.get("/")
|
196 |
async def get_home():
|
197 |
user_id = str(uuid.uuid4())
|
@@ -357,9 +350,11 @@ def train_unified_model():
|
|
357 |
loss.backward()
|
358 |
optimizer.step()
|
359 |
|
360 |
-
|
|
|
|
|
|
|
361 |
redis_client.set(f"model:unified_model", model_data_bytes)
|
362 |
-
|
363 |
redis_client.delete("training_queue")
|
364 |
time.sleep(60)
|
365 |
|
@@ -367,4 +362,4 @@ if __name__ == "__main__":
|
|
367 |
training_process = multiprocessing.Process(target=train_unified_model)
|
368 |
training_process.start()
|
369 |
import uvicorn
|
370 |
-
uvicorn.run(app, host="0.0.0.0", port=
|
|
|
1 |
from dotenv import load_dotenv
|
2 |
import os
|
3 |
import json
|
|
|
4 |
import redis
|
5 |
from transformers import (
|
6 |
AutoTokenizer,
|
|
|
185 |
else:
|
186 |
raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
|
187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
@app.get("/")
|
189 |
async def get_home():
|
190 |
user_id = str(uuid.uuid4())
|
|
|
350 |
loss.backward()
|
351 |
optimizer.step()
|
352 |
|
353 |
+
model_data_path = "model_data.pt"
|
354 |
+
torch.save(model.state_dict(), model_data_path)
|
355 |
+
with open(model_data_path, "rb") as f:
|
356 |
+
model_data_bytes = f.read()
|
357 |
redis_client.set(f"model:unified_model", model_data_bytes)
|
|
|
358 |
redis_client.delete("training_queue")
|
359 |
time.sleep(60)
|
360 |
|
|
|
362 |
training_process = multiprocessing.Process(target=train_unified_model)
|
363 |
training_process.start()
|
364 |
import uvicorn
|
365 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|