Yjhhh commited on
Commit
e9a0d8b
·
verified ·
1 Parent(s): 41894e4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +5 -10
main.py CHANGED
@@ -1,7 +1,6 @@
1
  from dotenv import load_dotenv
2
  import os
3
  import json
4
- import requests
5
  import redis
6
  from transformers import (
7
  AutoTokenizer,
@@ -186,12 +185,6 @@ async def process(request: Request):
186
  else:
187
  raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
188
 
189
- def get_chatbot_response(user_id, question, predicted_class, language):
190
- if user_id not in conversation_history:
191
- conversation_history[user_id] = []
192
- conversation_history[user_id].append(question)
193
- return chatbot_service.get_response(user_id, question, language)
194
-
195
  @app.get("/")
196
  async def get_home():
197
  user_id = str(uuid.uuid4())
@@ -357,9 +350,11 @@ def train_unified_model():
357
  loss.backward()
358
  optimizer.step()
359
 
360
- model_data_bytes = torch.save(model.state_dict(), "model_data.pt")
 
 
 
361
  redis_client.set(f"model:unified_model", model_data_bytes)
362
-
363
  redis_client.delete("training_queue")
364
  time.sleep(60)
365
 
@@ -367,4 +362,4 @@ if __name__ == "__main__":
367
  training_process = multiprocessing.Process(target=train_unified_model)
368
  training_process.start()
369
  import uvicorn
370
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from dotenv import load_dotenv
2
  import os
3
  import json
 
4
  import redis
5
  from transformers import (
6
  AutoTokenizer,
 
185
  else:
186
  raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
187
 
 
 
 
 
 
 
188
  @app.get("/")
189
  async def get_home():
190
  user_id = str(uuid.uuid4())
 
350
  loss.backward()
351
  optimizer.step()
352
 
353
+ model_data_path = "model_data.pt"
354
+ torch.save(model.state_dict(), model_data_path)
355
+ with open(model_data_path, "rb") as f:
356
+ model_data_bytes = f.read()
357
  redis_client.set(f"model:unified_model", model_data_bytes)
 
358
  redis_client.delete("training_queue")
359
  time.sleep(60)
360
 
 
362
  training_process = multiprocessing.Process(target=train_unified_model)
363
  training_process.start()
364
  import uvicorn
365
+ uvicorn.run(app, host="0.0.0.0", port=8000)