BenBranyon commited on
Commit
05660ce
·
verified ·
1 Parent(s): 50f12a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -1
app.py CHANGED
@@ -13,7 +13,7 @@ if torch.cuda.is_available():
13
  tokenizer = AutoTokenizer.from_pretrained(model_id)
14
  tokenizer.use_default_system_prompt = False
15
 
16
-
17
  def respond(
18
  message,
19
  history: list[tuple[str, str]],
@@ -45,6 +45,50 @@ def respond(
45
  response += token
46
  yield response
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  demo = gr.ChatInterface(
50
  respond,
 
13
  tokenizer = AutoTokenizer.from_pretrained(model_id)
14
  tokenizer.use_default_system_prompt = False
15
 
16
+ #Inference API Code
17
  def respond(
18
  message,
19
  history: list[tuple[str, str]],
 
45
  response += token
46
  yield response
47
 
48
+ #Transformers Code
49
+ @spaces.GPU
50
+ def generate(
51
+ message: str,
52
+ chat_history: list[tuple[str, str]],
53
+ system_prompt: str,
54
+ max_new_tokens: int = 1024,
55
+ temperature: float = 0.6,
56
+ top_p: float = 0.9,
57
+ top_k: int = 50,
58
+ repetition_penalty: float = 1.2,
59
+ ) -> Iterator[str]:
60
+ conversation = []
61
+ if system_prompt:
62
+ conversation.append({"role": "system", "content": system_prompt})
63
+ for user, assistant in chat_history:
64
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
65
+ conversation.append({"role": "user", "content": message})
66
+
67
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
68
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
69
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
70
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
71
+ input_ids = input_ids.to(model.device)
72
+
73
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
74
+ generate_kwargs = dict(
75
+ {"input_ids": input_ids},
76
+ streamer=streamer,
77
+ max_new_tokens=max_new_tokens,
78
+ do_sample=True,
79
+ top_p=top_p,
80
+ top_k=top_k,
81
+ temperature=temperature,
82
+ num_beams=1,
83
+ repetition_penalty=repetition_penalty,
84
+ )
85
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
86
+ t.start()
87
+
88
+ outputs = []
89
+ for text in streamer:
90
+ outputs.append(text)
91
+ yield "".join(outputs)
92
 
93
  demo = gr.ChatInterface(
94
  respond,