BeveledCube commited on
Commit
7a4525a
·
verified ·
1 Parent(s): 9da8faf

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +8 -9
main.py CHANGED
@@ -9,6 +9,14 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
9
  import torch
10
 
11
  app = FastAPI()
 
 
 
 
 
 
 
 
12
 
13
  class req(BaseModel):
14
  prompt: str
@@ -20,15 +28,6 @@ def read_root():
20
 
21
  @app.post("/api")
22
  def read_root(data: req):
23
- name = "microsoft/DialoGPT-medium"
24
- # microsoft/DialoGPT-small
25
- # microsoft/DialoGPT-medium
26
- # microsoft/DialoGPT-large
27
-
28
- # Load the Hugging Face GPT-2 model and tokenizer
29
- model = GPT2LMHeadModel.from_pretrained(name)
30
- tokenizer = GPT2Tokenizer.from_pretrained(name)
31
-
32
  print("Prompt:", data.prompt)
33
  print("Length:", data.length)
34
 
 
9
  import torch
10
 
11
  app = FastAPI()
12
+ name = "microsoft/DialoGPT-medium"
13
+ # microsoft/DialoGPT-small
14
+ # microsoft/DialoGPT-medium
15
+ # microsoft/DialoGPT-large
16
+
17
+ # Load the Hugging Face GPT-2 model and tokenizer
18
+ model = GPT2LMHeadModel.from_pretrained(name)
19
+ tokenizer = GPT2Tokenizer.from_pretrained(name)
20
 
21
  class req(BaseModel):
22
  prompt: str
 
28
 
29
  @app.post("/api")
30
  def read_root(data: req):
 
 
 
 
 
 
 
 
 
31
  print("Prompt:", data.prompt)
32
  print("Length:", data.length)
33