wenhu commited on
Commit
3c44d12
1 Parent(s): 71a2ab3

Update model/model_manager.py

Browse files
Files changed (1) hide show
  1. model/model_manager.py +1 -1
model/model_manager.py CHANGED
@@ -39,7 +39,7 @@ class ModelManager:
39
  @spaces.GPU(duration=30)
40
  def NSFW_filter(self, prompt):
41
  chat = [{"role": "user", "content": prompt}]
42
- input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
43
  self.guard.cuda()
44
  output = self.guard.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
45
  prompt_len = input_ids.shape[-1]
 
39
  @spaces.GPU(duration=30)
40
  def NSFW_filter(self, prompt):
41
  chat = [{"role": "user", "content": prompt}]
42
+ input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to('cuda')
43
  self.guard.cuda()
44
  output = self.guard.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
45
  prompt_len = input_ids.shape[-1]