ehristoforu committed on
Commit 7c32a88 · verified
1 Parent(s): 0975580

Update app.py

Files changed (1)
  1. app.py +13 -25
app.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, LlamaForCausalLM
 from peft import AutoPeftModelForCausalLM
 
 DESCRIPTION = """\
@@ -22,9 +22,9 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 HF_TOKEN = os.getenv("HF_TOKEN")
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-model_name = "ehristoforu/BigFalcon3-from10B"
+model_name = "estrogen/c4ai-command-r7b-12-2024"
 
-model = AutoModelForCausalLM.from_pretrained(
+model = LlamaForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch.float16,
     trust_remote_code=True
@@ -54,38 +54,26 @@ api.upload_folder(
 @spaces.GPU(duration=60)
 def generate(
     message: str,
-    chat_history: list[tuple[str, str]],
+    chat_history: list[dict],
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = []
-    for user, assistant in chat_history:
-        conversation.extend(
-            [
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant},
-            ]
-        )
-    conversation.append({"role": "user", "content": message})
-
-    formatted = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(formatted, return_tensors="pt", padding=True)
-    #if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-    #    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-    #    gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    inputs = inputs.to(model.device)
-    attention_mask = inputs["attention_mask"]
+    conversation = [*chat_history, {"role": "user", "content": message}]
+
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        {"input_ids": inputs["input_ids"]},
+        {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
-        #eos_token_id=tokenizer.eos_token_id,
-        pad_token_id=tokenizer.eos_token_id,
-        attention_mask=attention_mask,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
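
For reference, here is a minimal sketch of how the rewritten streaming path fits together after this commit. The conversation build, chat-template call, truncation, and generate_kwargs mirror the + side of the hunk above; model, tokenizer, gr, TextIteratorStreamer, and MAX_INPUT_TOKEN_LENGTH are the module-level objects defined earlier in app.py, and the final Thread/yield section is an assumption (it sits below the context shown in the hunk and follows the usual TextIteratorStreamer pattern), not something visible in the diff.

# Sketch only, not part of the commit: end-to-end streaming flow after this change.
# `model`, `tokenizer`, `gr`, `TextIteratorStreamer`, and MAX_INPUT_TOKEN_LENGTH come from
# app.py's module scope; the Thread/yield tail is assumed (standard streamer usage),
# since it falls outside the context shown in the hunk.
from threading import Thread
from typing import Iterator

def generate_sketch(message: str, chat_history: list[dict], max_new_tokens: int = 1024) -> Iterator[str]:
    # History now arrives as OpenAI-style {"role": ..., "content": ...} dicts,
    # so the old (user, assistant) tuple unpacking is gone.
    conversation = [*chat_history, {"role": "user", "content": message}]

    # apply_chat_template with return_tensors="pt" tokenizes directly, replacing the earlier
    # tokenize=False + tokenizer(...) round trip and the manual attention_mask handling.
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # keep only the most recent tokens
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )

    # Assumed continuation: run generation on a worker thread and stream partial text back.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)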
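
The signature change from chat_history: list[tuple[str, str]] to chat_history: list[dict] matches Gradio's messages-style history, where each turn is a {"role": ..., "content": ...} dict. The UI wiring is not part of this hunk; in recent Gradio releases it would plausibly look like the sketch below (the additional_inputs sliders are illustrative, not taken from the Space).

# Assumed wiring, not shown in this commit: with type="messages", ChatInterface passes
# history to `generate` as a list of {"role": ..., "content": ...} dicts, matching the
# new annotation. Extra sliders map onto generate()'s keyword arguments in order.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
    ],
)

if __name__ == "__main__":
    demo.queue().launch()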