BeveledCube commited on
Commit
c36c2b7
1 Parent(s): 40235ed

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +7 -1
main.py CHANGED
@@ -13,6 +13,11 @@ name = "microsoft/DialoGPT-medium"
13
  model = GPT2LMHeadModel.from_pretrained(name)
14
  tokenizer = GPT2Tokenizer.from_pretrained(name)
15
 
 
 
 
 
 
16
  @app.route("/api", methods=["POST"])
17
  def receive_data():
18
  data = request.get_json()
@@ -33,7 +38,8 @@ def receive_data():
33
  print("Answered with:", answer_data)
34
  return jsonify(answer_data)
35
 
36
- @app.route("/", methods=["GET"])
 
37
  def not_api():
38
  return render_template("index.html")
39
 
 
13
  model = GPT2LMHeadModel.from_pretrained(name)
14
  tokenizer = GPT2Tokenizer.from_pretrained(name)
15
 
16
+ # Using CUDA for an optimal experience
17
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ model = model.to(device)
19
+
20
+ # Open a thing for the API
21
  @app.route("/api", methods=["POST"])
22
  def receive_data():
23
  data = request.get_json()
 
38
  print("Answered with:", answer_data)
39
  return jsonify(answer_data)
40
 
41
+ # Incase a normal browser opens the page
42
+ @app.route("/")
43
  def not_api():
44
  return render_template("index.html")
45