Update app.py
app.py
CHANGED
@@ -12,17 +12,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 DESCRIPTION = "# Sakaltum-7B-chat"
 
 if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo might be slower on CPU.</p>"
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))
 
+model_id = "sakaltcommunity/sakaltum-7b"
 if torch.cuda.is_available():
-    model_id = "sakaltcommunity/sakaltum-7b"
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
-    model.eval()
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+else:
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+model.eval()
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 def apply_chat_template(conversation: list[dict[str, str]]) -> str:
@@ -31,7 +34,6 @@ def apply_chat_template(conversation: list[dict[str, str]]) -> str:
     return prompt
 
 
-@spaces.GPU
 @torch.inference_mode()
 def generate(
     message: str,
@@ -56,7 +58,7 @@ def generate(
 
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        {"input_ids": input_ids},
+        input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
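For context, the first hunk hoists `model_id` out of the CUDA-only branch and adds a CPU fallback, so a CPU-only Space no longer reaches generation with `model` undefined. A minimal sketch of the post-commit loading logic, assuming the imports already at the top of app.py:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sakaltcommunity/sakaltum-7b"

if torch.cuda.is_available():
    # torch_dtype="auto" keeps the checkpoint's own dtype; device_map="auto"
    # lets accelerate place the weights across the available GPUs.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype="auto", device_map="auto"
    )
else:
    # CPU fallback: bfloat16 roughly halves memory versus float32.
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.eval()  # inference only: disables dropout and other train-time behavior

tokenizer = AutoTokenizer.from_pretrained(model_id)
```

Note that `device_map="auto"` requires the accelerate package; the CPU branch deliberately omits it so the weights stay on the CPU.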
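The last hunk passes `input_ids` to `generate_kwargs` as a keyword argument. That dict feeds the standard transformers streaming pattern: `model.generate` runs on a worker thread while the `TextIteratorStreamer` yields decoded text as tokens arrive. The rest of the `generate` body is outside this diff, so the sketch below is a generic reconstruction, not the Space's exact code; the helper name `stream_reply` and the fixed sampling settings are illustrative assumptions:

```python
from threading import Thread

from transformers import TextIteratorStreamer


def stream_reply(model, tokenizer, prompt: str, max_new_tokens: int = 1024):
    # Tokenize the prompt and move it to the model's device.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # skip_prompt=True yields only newly generated tokens, not the echoed prompt.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,  # keyword form, as in this commit
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    # generate() blocks until completion, so run it on a background thread
    # and consume the streamer here as text fragments become available.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    chunks = []
    for text in streamer:
        chunks.append(text)
        yield "".join(chunks)  # accumulated reply so far, as a chat UI expects
```

In app.py the real `generate` is additionally wrapped in `@torch.inference_mode()`, so no gradients are tracked during generation.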