Tuchuanhuhuhu committed on
Commit 33cbbdb · 1 Parent(s): 5f0c62a

StableLM: support streaming output

Files changed (1)
  1. modules/models/StableLM.py +48 -58
modules/models/StableLM.py CHANGED
@@ -1,10 +1,14 @@
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 import time
 import numpy as np
 from torch.nn import functional as F
 import os
 from .base_model import BaseLLMModel
+from threading import Thread
+
+STABLELM_MODEL = None
+STABLELM_TOKENIZER = None
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -17,11 +21,18 @@ class StopOnTokens(StoppingCriteria):
 class StableLM_Client(BaseLLMModel):
     def __init__(self, model_name) -> None:
         super().__init__(model_name=model_name)
+        global STABLELM_MODEL, STABLELM_TOKENIZER
         print(f"Starting to load StableLM to memory")
-        self.model = AutoModelForCausalLM.from_pretrained(
-            "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()
-        self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")
-        self.generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer, device=0)
+        if model_name == "StableLM":
+            model_name = "stabilityai/stablelm-tuned-alpha-7b"
+        else:
+            model_name = f"models/{model_name}"
+        if STABLELM_MODEL is None:
+            STABLELM_MODEL = AutoModelForCausalLM.from_pretrained(
+                model_name, torch_dtype=torch.float16).cuda()
+        if STABLELM_TOKENIZER is None:
+            STABLELM_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
+        self.generator = pipeline('text-generation', model=STABLELM_MODEL, tokenizer=STABLELM_TOKENIZER, device=0)
         print(f"Sucessfully loaded StableLM to the memory")
         self.system_prompt = """StableAssistant
 - StableAssistant is A helpful and harmless Open Source AI Language Model developed by Stability and CarperAI.
@@ -29,67 +40,46 @@ class StableLM_Client(BaseLLMModel):
 - StableAssistant is more than just an information source, StableAssistant is also able to write poetry, short stories, and make jokes.
 - StableAssistant will refuse to participate in anything that could harm a human."""
 
-    def user(self, user_message, history):
-        history = history + [[user_message, ""]]
-        return "", history, history
-
-
-    def bot(self, history, curr_system_message):
-        messages = f"<|SYSTEM|># {self.system_prompt}" + \
-            "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
-                     for item in history])
-        output = self.generate(messages)
-        history[-1][1] = output
-        time.sleep(1)
-        return history, history
-
     def _get_stablelm_style_input(self):
+        history = self.history + [{"role": "assistant", "content": ""}]
+        print(history)
         messages = self.system_prompt + \
-            "".join(["".join(["<|USER|>"+self.history[i]["content"], "<|ASSISTANT|>"+self.history[i + 1]["content"]])
-                     for i in range(0, len(self.history), 2)])
+            "".join(["".join(["<|USER|>"+history[i]["content"], "<|ASSISTANT|>"+history[i + 1]["content"]])
+                     for i in range(0, len(history), 2)])
         return messages
 
-    def generate(self, text, bad_text=None):
+    def _generate(self, text, bad_text=None):
         stop = StopOnTokens()
         result = self.generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
                                 temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
         return result[0]["generated_text"].replace(text, "")
 
-    def contrastive_generate(self, text, bad_text):
-        with torch.no_grad():
-            tokens = self.tokenizer(text, return_tensors="pt")[
-                'input_ids'].cuda()[:, :4096-1024]
-            bad_tokens = self.tokenizer(bad_text, return_tensors="pt")[
-                'input_ids'].cuda()[:, :4096-1024]
-            history = None
-            bad_history = None
-            curr_output = list()
-            for i in range(1024):
-                out = self.model(tokens, past_key_values=history, use_cache=True)
-                logits = out.logits
-                history = out.past_key_values
-                bad_out = self.model(bad_tokens, past_key_values=bad_history,
-                                     use_cache=True)
-                bad_logits = bad_out.logits
-                bad_history = bad_out.past_key_values
-                probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
-                bad_probs = F.softmax(bad_logits.float(), dim=-1)[0][-1].cpu()
-                logits = torch.log(probs)
-                bad_logits = torch.log(bad_probs)
-                logits[probs > 0.1] = logits[probs > 0.1] - bad_logits[probs > 0.1]
-                probs = F.softmax(logits)
-                out = int(torch.multinomial(probs, 1))
-                if out in [50278, 50279, 50277, 1, 0]:
-                    break
-                else:
-                    curr_output.append(out)
-                out = np.array([out])
-                tokens = torch.from_numpy(np.array([out])).to(
-                    tokens.device)
-                bad_tokens = torch.from_numpy(np.array([out])).to(
-                    tokens.device)
-        return self.tokenizer.decode(curr_output)
-
     def get_answer_at_once(self):
         messages = self._get_stablelm_style_input()
-        return self.generate(messages)
+        return self._generate(messages), len(messages)
+
+    def get_answer_stream_iter(self):
+        stop = StopOnTokens()
+        messages = self._get_stablelm_style_input()
+
+        #model_inputs = tok([messages], return_tensors="pt")['input_ids'].cuda()[:, :4096-1024]
+        model_inputs = STABLELM_TOKENIZER([messages], return_tensors="pt").to("cuda")
+        streamer = TextIteratorStreamer(STABLELM_TOKENIZER, timeout=10., skip_prompt=True, skip_special_tokens=True)
+        generate_kwargs = dict(
+            model_inputs,
+            streamer=streamer,
+            max_new_tokens=1024,
+            do_sample=True,
+            top_p=0.95,
+            top_k=1000,
+            temperature=1.0,
+            num_beams=1,
+            stopping_criteria=StoppingCriteriaList([stop])
+        )
+        t = Thread(target=STABLELM_MODEL.generate, kwargs=generate_kwargs)
+        t.start()
+
+        partial_text = ""
+        for new_text in streamer:
+            partial_text += new_text
+            yield partial_text
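
The new get_answer_stream_iter() runs generate() on a worker thread and reads partial text from a TextIteratorStreamer, so callers can render the reply as it is produced instead of waiting for _generate() to return. Below is a minimal consumption sketch, not part of the commit; it assumes the repository layout shown above and that BaseLLMModel keeps history as the {"role": ..., "content": ...} dicts that _get_stablelm_style_input expects.

from modules.models.StableLM import StableLM_Client

# "StableLM" resolves to stabilityai/stablelm-tuned-alpha-7b in __init__.
client = StableLM_Client("StableLM")

# Hypothetical single-turn history; real callers go through BaseLLMModel's
# higher-level predict flow rather than assigning this directly.
client.history = [{"role": "user", "content": "Write a haiku about GPUs."}]

# The generator yields the accumulated reply so far; print only the delta.
last = ""
for partial_text in client.get_answer_stream_iter():
    print(partial_text[len(last):], end="", flush=True)
    last = partial_text
print()

Because generation happens on a background thread, the streamer's timeout=10. makes the consuming loop raise instead of blocking indefinitely if that thread stalls or crashes.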