Prot10 commited on
Commit
20cc109
·
1 Parent(s): fe08fd6

Model created

Browse files
Files changed (1) hide show
  1. app.py +162 -6
app.py CHANGED
@@ -1,15 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- sentiment = pipeline("sentiment-analysis")
 
5
 
6
- def get_sentiment(input_text):
7
- return sentiment(input_text)
 
 
 
 
 
8
 
9
- iface = gr.Interface(fn = get_sentiment,
10
  inputs = "text",
11
  outputs = ["text"],
12
- title = "Sentiment Analysis",
13
  description = "Ciao!!!")
14
 
15
  iface.launch(inline = False)
 
1
+ #import gradio as gr
2
+ #from transformers import pipeline
3
+
4
+ #sentiment = pipeline("sentiment-analysis")
5
+
6
+ #def get_sentiment(input_text):
7
+ # return sentiment(input_text)
8
+
9
+ #iface = gr.Interface(fn = get_sentiment,
10
+ # inputs = "text",
11
+ # outputs = ["text"],
12
+ # title = "Sentiment Analysis",
13
+ # description = "Ciao!!!")
14
+ #
15
+ #iface.launch(inline = False)
16
+
17
  import gradio as gr
18
+ from typing import *
19
+ import torch
20
+ import transformers
21
+
22
+ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
23
+
24
+ tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
25
+ model = LlamaForCausalLM.from_pretrained(
26
+ "decapoda-research/llama-7b-hf",
27
+ load_in_8bit=True,
28
+ device_map="auto",
29
+ )
30
+
31
+ def evaluate(question):
32
+ prompt = f"The conversation between human and AI assistant.\n[|Human|] {question}.\n[|AI|] "
33
+ inputs = tokenizer(question, return_tensors="pt")
34
+ input_ids = inputs["input_ids"].cuda()
35
+ generation_output = model.generate(
36
+ input_ids=input_ids,
37
+ generation_config=GenerationConfig(
38
+ temperature=1,
39
+ top_p=0.95,
40
+ num_beams=4,
41
+ max_context_length_tokens=2048,
42
+ ),
43
+ return_dict_in_generate=True,
44
+ output_scores=True,
45
+ max_new_tokens=512
46
+ )
47
+ output = tokenizer.decode(generation_output.sequences[0]).split("[|AI|]")[1]
48
+ return output
49
+
50
+
51
+ def generate_prompt_with_history(text:str, history: str, tokenizer, max_length=2048):
52
+ history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0],x[1]) for x in history]
53
+ history.append("\n[|Human|]{}\n[|AI|]".format(text))
54
+ history_text = ""
55
+
56
+ for x in history[::-1]:
57
+ if tokenizer(history_text + x, return_tensors="pt")['input_ids'].size(-1) <= max_length:
58
+ history_text = x + history_text
59
+ flag = True
60
+ if flag:
61
+ return history_text, tokenizer(history_text, return_tensors="pt")
62
+ else:
63
+ return False
64
+
65
+
66
+ def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
67
+ for stop_word in stop_words:
68
+ if s.endswith(stop_word):
69
+ return True
70
+ for i in range(1, len(stop_word)):
71
+ if s.endswith(stop_word[:i]):
72
+ return True
73
+ return False
74
+
75
+
76
+ def greedy_search(input_ids: torch.Tensor,
77
+ model: torch.nn.Module,
78
+ tokenizer: transformers.PreTrainedTokenizer,
79
+ stop_words: list,
80
+ max_length: int,
81
+ temperature: float = 1.0,
82
+ top_p: float = 1.0,
83
+ top_k: int = 25) -> Iterator[str]:
84
+ generated_tokens = []
85
+ past_key_values = None
86
+ current_length = 1
87
+ for i in range(max_length):
88
+ with torch.no_grad():
89
+ if past_key_values is None:
90
+ outputs = model(input_ids)
91
+ else:
92
+ outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
93
+ logits = outputs.logits[:, -1, :]
94
+ past_key_values = outputs.past_key_values
95
+
96
+ logits /= temperature
97
+
98
+ probs = torch.softmax(logits, dim=-1)
99
+
100
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
101
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
102
+ mask = probs_sum - probs_sort > top_p
103
+ probs_sort[mask] = 0.0
104
+
105
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
106
+ next_token = torch.multinomial(probs_sort, num_samples=1)
107
+ next_token = torch.gather(probs_idx, -1, next_token)
108
+
109
+ input_ids = torch.cat((input_ids, next_token), dim=-1)
110
+
111
+ generated_tokens.append(next_token[0].item())
112
+ text = tokenizer.decode(generated_tokens)
113
+
114
+ yield text
115
+ if any([x in text for x in stop_words]):
116
+ return
117
+ @torch.no_grad()
118
+
119
+
120
+ def predict(text:str,
121
+ chatbot,
122
+ history:str = "",
123
+ top_p:float = 0.95,
124
+ temperature:float = 1.0,
125
+ max_length_tokens:int = 512,
126
+ max_context_length_tokens:int = 2048):
127
+ if text=="":
128
+ return ""
129
+
130
+ inputs = generate_prompt_with_history(text, history, tokenizer, max_length=max_context_length_tokens)
131
+ prompt,inputs=inputs
132
+ begin_length = len(prompt)
133
+
134
+ input_ids = inputs["input_ids"].to(chatbot.device)
135
+ output = []
136
+
137
+ for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
138
+ if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
139
+ if "[|Human|]" in x:
140
+ x = x[:x.index("[|Human|]")].strip()
141
+ elif "[| Human |]" in x:
142
+ x = x[:x.index("[| Human |]")].strip()
143
+ if "[|AI|]" in x:
144
+ x = x[:x.index("[|AI|]")].strip()
145
+ x = x.strip(" ")
146
+ output.append(x)
147
+ return output[-1]
148
+
149
+ #text = "Can you give a more formal definition?"
150
+ #print(predict(text, model))
151
+
152
+ #sentiment = pipeline("sentiment-analysis")
153
 
154
+ #def get_sentiment(input_text):
155
+ # return sentiment(input_text)
156
 
157
+ #iface = gr.Interface(fn = get_sentiment,
158
+ # inputs = "text",
159
+ # outputs = ["text"],
160
+ # title = "Sentiment Analysis",
161
+ # description = "Ciao!!!")
162
+ #
163
+ #iface.launch(inline = False)
164
 
165
+ iface = gr.Interface(fn = predict,
166
  inputs = "text",
167
  outputs = ["text"],
168
+ title = "Learn with ChadGPT",
169
  description = "Ciao!!!")
170
 
171
  iface.launch(inline = False)