ranamhamoud commited on
Commit
469c0f9
·
verified ·
1 Parent(s): 1fa6ab9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -70,14 +70,15 @@ def generate(
70
  if model == "A":
71
  model = modelA
72
  tokenizer = tokenizerA
73
- enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
74
- input_ids = enc.input_ids.to(model.device)
75
 
76
  else:
77
  model = modelB
78
  tokenizer = tokenizerB
79
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
80
-
 
 
81
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
82
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
83
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
70
  if model == "A":
71
  model = modelA
72
  tokenizer = tokenizerA
73
+
 
74
 
75
  else:
76
  model = modelB
77
  tokenizer = tokenizerB
78
+
79
+ enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
80
+ input_ids = enc.input_ids.to(model.device)
81
+
82
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
83
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
84
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")