Tonic committed
Commit ea7c9d2
Parent: 7c96374

Update app.py

Files changed (1): app.py (+2, -2)
app.py CHANGED
@@ -21,7 +21,7 @@ model_name = "OpenLLM-France/Claire-7B-0.1"
 tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
 model = transformers.AutoModelForCausalLM.from_pretrained(model_name,
     device_map="auto",
-    torch_dtype=torch.bfloat16.to("cuda"),
+    torch_dtype=torch.bfloat16,
     load_in_4bit=True  # For efficient inference, if supported by the GPU card
 )
 model = model.to_bettertransformer()
@@ -58,6 +58,7 @@ class FalconChatBot:
         conversation = f"{self.system_prompt}\n {assistant_message if assistant_message else ''}\n {user_message}\n "
         # Encode the conversation using the tokenizer
         input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False)
+        input_ids = input_ids.to(device)
         # Generate a response using the Falcon model
         response = model.generate(
             input_ids=input_ids,
@@ -76,7 +77,6 @@ class FalconChatBot:
 
         # Decode the generated response to text
         response_text = tokenizer.decode(response[0], skip_special_tokens=True)
-
         # Update and return the history with the new conversation
         updated_history = processed_history + [{"user": user_message, "assistant": response_text}]
         return response_text, updated_history
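For context on the first hunk: torch.bfloat16 is a dtype object, not a tensor, so torch.bfloat16.to("cuda") raises an AttributeError before the model ever loads. The fix passes the dtype straight through and leaves device placement to device_map="auto". A minimal sketch of the corrected loading pattern, assuming torch, transformers, accelerate, and bitsandbytes are installed and a CUDA GPU is available:

import torch
import transformers

model_name = "OpenLLM-France/Claire-7B-0.1"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",           # let accelerate place layers on the available devices
    torch_dtype=torch.bfloat16,  # a dtype object; it has no .to() method
    load_in_4bit=True,           # 4-bit quantization for efficient inference, if the GPU supports it
)
model = model.to_bettertransformer()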
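And for the second hunk: tokenizer.encode returns tensors on the CPU while the quantized model sits on the GPU, so model.generate can fail with a device-mismatch error. The commit moves input_ids to a `device` variable before generating; the diff does not show how `device` is defined, so the sketch below substitutes model.device, which resolves correctly under device_map="auto". It reuses `model` and `tokenizer` from the sketch above; the prompt values and max_new_tokens are illustrative, since the diff truncates the generate() arguments:

# Hypothetical prompt values; app.py builds these from the chat history.
system_prompt = "Tu es Claire, une assistante francophone."
user_message = "Bonjour !"
conversation = f"{system_prompt}\n \n {user_message}\n "

input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False)
input_ids = input_ids.to(model.device)  # the commit uses a `device` variable; model.device is an equivalent choice

response = model.generate(input_ids=input_ids, max_new_tokens=100)  # max_new_tokens is illustrative
response_text = tokenizer.decode(response[0], skip_special_tokens=True)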