BeveledCube committed
Commit 9920987
Parent: 781452b

Update main.py

Files changed (1):
  1. main.py +13 -50
main.py CHANGED
@@ -9,9 +9,8 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, A
 import torch
 
 app = FastAPI()
-name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+name = "microsoft/DialoGPT-medium"
 customGen = False
-gpt2based = False
 
 # microsoft/DialoGPT-small
 # microsoft/DialoGPT-medium
@@ -38,53 +37,17 @@ def read_root():
 def read_root(data: req):
     print("Prompt:", data.prompt)
     print("Length:", data.length)
-
-    if (name == "microsoft/DialoGPT-small" or name == "microsoft/DialoGPT-medium" or name == "microsoft/DialoGPT-large") and customGen == True:
-        # tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
-        # model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
-
-        step = 1
-
-        # encode the new user input, add the eos_token and return a tensor in Pytorch
-        new_user_input_ids = tokenizer.encode(data.prompt + tokenizer.eos_token, return_tensors='pt')
-
-        # append the new user input tokens to the chat history
-        bot_input_ids = torch.cat(new_user_input_ids, dim=-1) if step > 0 else new_user_input_ids
 
-        # generated a response while limiting the total chat history to 1000 tokens,
-        chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
+    input_text = data.prompt
 
-        generated_text = tokenizer.decode(chat_history_ids[:, :][0], skip_special_tokens=True)
-        answer_data = { "answer": generated_text }
-        print("Answer:", generated_text)
-
-        return answer_data
-    else:
-        if gpt2based == True:
-            input_text = data.prompt
-
-            # Tokenize the input text
-            input_ids = gpt2tokenizer.encode(input_text, return_tensors="pt")
-
-            # Generate output using the model
-            output_ids = gpt2model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
-            generated_text = gpt2tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-            answer_data = { "answer": generated_text }
-            print("Answer:", generated_text)
-
-            return answer_data
-        else:
-            input_text = data.prompt
-
-            # Tokenize the input text
-            input_ids = tokenizer.encode(input_text, return_tensors="pt")
-
-            # Generate output using the model
-            output_ids = model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
-            generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-            answer_data = { "answer": generated_text }
-            print("Answer:", generated_text)
-
-            return answer_data
+    # Tokenize the input text
+    input_ids = tokenizer.encode(input_text, return_tensors="pt")
+
+    # Generate output using the model
+    output_ids = model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
+    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+    answer_data = { "answer": generated_text }
+    print("Answer:", generated_text)
+
+    return answer_data
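
For context, a minimal, self-contained sketch of the generation path the simplified endpoint now takes. The loading of `tokenizer` and `model` from `name`, the Pydantic `req` model with `prompt` and `length` fields, and the route path are assumptions not shown in this diff (the commented-out lines in the removed code suggest the Auto classes are used):

# Sketch only: the model/tokenizer loading, the req schema, and the "/" route
# are assumptions inferred from the diff, not taken verbatim from main.py.
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

app = FastAPI()

class req(BaseModel):
    prompt: str
    length: int

@app.post("/")
def read_root(data: req):
    # Tokenize the prompt, generate with beam search, and decode, as in the new code path
    input_ids = tokenizer.encode(data.prompt, return_tensors="pt")
    output_ids = model.generate(input_ids, max_length=data.length, num_beams=5, no_repeat_ngram_size=2)
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"answer": generated_text}

Served with, e.g., `uvicorn main:app`, the endpoint can then be exercised with a JSON body such as {"prompt": "Hello there", "length": 64}; the port and route are whatever the deployment actually configures.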