BeveledCube committed
Commit 66ed14b · verified · 1 Parent(s): 1a3bc85

Update main.py

Files changed (1): main.py (+29 -12)
main.py CHANGED
@@ -10,6 +10,8 @@ import torch
 
 app = FastAPI()
 name = "microsoft/DialoGPT-small"
+customGen = False
+gpt2based = False
 
 # microsoft/DialoGPT-small
 # microsoft/DialoGPT-medium
@@ -37,7 +39,7 @@ def read_root(data: req):
     print("Prompt:", data.prompt)
     print("Length:", data.length)
 
-    if name == "microsoft/DialoGPT-small" or name == "microsoft/DialoGPT-medium" or name == "microsoft/DialoGPT-large":
+    if (name == "microsoft/DialoGPT-small" or name == "microsoft/DialoGPT-medium" or name == "microsoft/DialoGPT-large") and customGen == True:
         # tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
         # model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
 
@@ -58,16 +60,31 @@ def read_root(data: req):
 
         return answer_data
     else:
-        input_text = data.prompt
+        if gpt2based == True:
+            input_text = data.prompt
 
-        # Tokenize the input text
-        input_ids = gpt2tokenizer.encode(input_text, return_tensors="pt")
+            # Tokenize the input text
+            input_ids = gpt2tokenizer.encode(input_text, return_tensors="pt")
+
+            # Generate output using the model
+            output_ids = gpt2model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
+            generated_text = gpt2tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+            answer_data = { "answer": generated_text }
+            print("Answer:", generated_text)
+
+            return answer_data
+        else:
+            input_text = data.prompt
 
-        # Generate output using the model
-        output_ids = model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
-        generated_text = gpt2tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-        answer_data = { "answer": generated_text }
-        print("Answer:", generated_text)
-
-        return answer_data
+            # Tokenize the input text
+            input_ids = tokenizer.encode(input_text, return_tensors="pt")
+
+            # Generate output using the model
+            output_ids = model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
+            generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+            answer_data = { "answer": generated_text }
+            print("Answer:", generated_text)
+
+            return answer_data
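
For context, a minimal standalone sketch of the beam-search generation path this commit adds. The checkpoint name "gpt2" and the loading calls are assumptions for illustration; in main.py, gpt2model and gpt2tokenizer are defined outside the hunks shown above.

# Sketch of the gpt2based branch, assuming a stand-in "gpt2" checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

gpt2tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2model = AutoModelForCausalLM.from_pretrained("gpt2")

# Tokenize a prompt, as the endpoint does with data.prompt.
input_ids = gpt2tokenizer.encode("Hello, how are you?", return_tensors="pt")

# num_beams=5 runs beam search; no_repeat_ngram_size=2 forbids repeating any bigram.
output_ids = gpt2model.generate(input_ids, max_length=64, num_beams=5, no_repeat_ngram_size=2)
print(gpt2tokenizer.decode(output_ids[0], skip_special_tokens=True))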
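
A hypothetical client call against the updated endpoint. The route path "/" and port 8000 are assumptions, since the decorator on read_root is not part of this diff; the prompt and length fields mirror data.prompt and data.length above.

# Hypothetical request; path and port are assumed, the JSON shape follows the req model.
import requests

resp = requests.post("http://localhost:8000/", json={"prompt": "Hello, how are you?", "length": 64})
print(resp.json()["answer"])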