Sakalti committed
Commit 4c541ae · verified · 1 Parent(s): 67997f6

Update app.py

Files changed (1)
  1. app.py  +8 -6
app.py CHANGED
@@ -12,17 +12,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 DESCRIPTION = "# Sakaltum-7B-chat"
 
 if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo might be slower on CPU.</p>"
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))
 
+model_id = "sakaltcommunity/sakaltum-7b"
 if torch.cuda.is_available():
-    model_id = "sakaltcommunity/sakaltum-7b"
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
-    model.eval()
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+else:
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+model.eval()
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 def apply_chat_template(conversation: list[dict[str, str]]) -> str:
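Before this hunk, `model_id`, `model`, and `tokenizer` were only defined inside the CUDA branch, so a CPU-only host would hit a `NameError` as soon as anything referenced them; hoisting `model_id` and adding an `else` branch fixes that. A minimal sketch of the resulting load path (standard `transformers`/`torch` usage; nothing here goes beyond what the diff shows except the comments):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sakaltcommunity/sakaltum-7b"

if torch.cuda.is_available():
    # GPU path: torch_dtype="auto" keeps the checkpoint's native dtype and
    # device_map="auto" lets accelerate place the weights across GPUs.
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
else:
    # CPU path: bfloat16 halves memory versus float32, at the cost of
    # slower matmuls on CPUs without native bf16 support.
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

model.eval()  # disable dropout etc. in both branches
tokenizer = AutoTokenizer.from_pretrained(model_id)
```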
 
@@ -31,7 +34,6 @@ def apply_chat_template(conversation: list[dict[str, str]]) -> str:
     return prompt
 
 
-@spaces.GPU
 @torch.inference_mode()
 def generate(
     message: str,
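The dropped `@spaces.GPU` decorator comes from Hugging Face's `spaces` package; on ZeroGPU Spaces it requests GPU hardware for the duration of each call, so removing it fits the commit's move toward letting the demo fall back to CPU. For reference, a minimal sketch of how the decorator stack looked before this commit (the function body is elided; only the decorators are from the diff):

```python
import spaces  # Hugging Face Spaces helper package (ZeroGPU)
import torch

@spaces.GPU              # removed in this commit: no per-call GPU request
@torch.inference_mode()
def generate(message: str) -> str:
    ...  # generation body elided in this sketch
```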
 
@@ -56,7 +58,7 @@ def generate(
 
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        {"input_ids": input_ids},
+        input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
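The old `dict({"input_ids": input_ids}, ...)` form was actually valid Python, since `dict` merges a positional mapping with keyword arguments, so this hunk is a readability fix rather than a bug fix. The rest of `generate` is outside the diff, but these kwargs feed the standard `TextIteratorStreamer` idiom, roughly as below (a hedged sketch; `stream_reply` is a hypothetical name, not the file's actual code):

```python
from threading import Thread

def stream_reply(model, generate_kwargs, streamer):
    """Hypothetical continuation of generate(): run generation on a
    background thread and yield the accumulated text as it streams in."""
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:  # TextIteratorStreamer yields decoded text chunks
        outputs.append(text)
        yield "".join(outputs)
    t.join()
```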