Spaces (status: Runtime error)
ehristoforu committed: Update app.py

app.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, LlamaForCausalLM
 from peft import AutoPeftModelForCausalLM
 
 DESCRIPTION = """\
@@ -22,9 +22,9 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 HF_TOKEN = os.getenv("HF_TOKEN")
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-model_name = "
+model_name = "estrogen/c4ai-command-r7b-12-2024"
 
-model =
+model = LlamaForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch.float16,
     trust_remote_code=True
@@ -54,38 +54,26 @@ api.upload_folder(
 @spaces.GPU(duration=60)
 def generate(
     message: str,
-    chat_history: list[
+    chat_history: list[dict],
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = []
-
-
-
-
-
-
-
-    conversation.append({"role": "user", "content": message})
-
-    formatted = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(formatted, return_tensors="pt", padding=True)
-    #if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-    #    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-    #    gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    inputs = inputs.to(model.device)
-    attention_mask = inputs["attention_mask"]
+    conversation = [*chat_history, {"role": "user", "content": message}]
+
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        {"input_ids":
+        {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
-        #eos_token_id=tokenizer.eos_token_id,
-        pad_token_id=tokenizer.eos_token_id,
-        attention_mask=attention_mask,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
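The hunk cuts off inside generate_kwargs, so the rest of the handler is not visible in this commit. For orientation only, the TextIteratorStreamer pattern used here is normally completed by running model.generate on a worker thread and yielding the accumulated text as it arrives. The sketch below illustrates that standard pattern under stated assumptions; the helper name stream_generate, its arguments, and its defaults are illustrative, not the Space's actual code.

# Illustrative sketch of the standard TextIteratorStreamer completion pattern.
# Names, arguments, and defaults are assumptions; this code is not part of the commit.
from threading import Thread
from typing import Iterator

from transformers import TextIteratorStreamer


def stream_generate(model, tokenizer, input_ids, max_new_tokens: int = 1024) -> Iterator[str]:
    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    # model.generate blocks until decoding finishes, so it runs on a worker thread
    # while the caller consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # Yielding the running concatenation lets a chat UI re-render the partial reply.
        yield "".join(outputs)

In app.py, the same consume-and-yield loop would presumably follow the generate_kwargs shown in the hunk, with the remaining sampling parameters passed through to model.generate.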
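How generate is exposed in the UI is also outside this diff. Given the signature (a message plus messages-format chat_history, with the sampling knobs as extra inputs), a typical Gradio hookup would look roughly like the following; every widget, label, and default here is an assumption for illustration, not code from this Space.

# Hypothetical UI wiring for a generate(message, chat_history, ...) handler; not part of this commit.
import gradio as gr

demo = gr.ChatInterface(
    fn=generate,
    type="messages",  # history arrives as a list of {"role": ..., "content": ...} dicts (recent Gradio releases)
    description=DESCRIPTION,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=200, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
)

if __name__ == "__main__":
    demo.queue().launch()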