LPN64 commited on
Commit
e4dff56
·
verified ·
1 Parent(s): 9b6af1b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +180 -1
README.md CHANGED
@@ -5,4 +5,183 @@ base_model:
5
  ---
6
  GGUF version of longcite, you need to add the following tokens as stop tokens : `[128000, 128007, 128009]` or `["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"]`
7
 
8
- Be default, and it seems to be working so far, EOS token is 128007 (end_header_id). Working for citation and naive question-answer mode.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  ---
6
  GGUF version of LongCite — you need to add the following tokens as stop tokens: `[128000, 128007, 128009]` or `["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"]`
7
 
8
+ By default — and it seems to be working so far — the EOS token is 128007 (end_header_id). This works for both citation mode and naive question-answer mode.
9
+
10
+ Example code
11
+ ```python
12
+
13
+ from nltk.tokenize import PunktSentenceTokenizer
14
+ import re
15
+
16
class LongCiteModel:
    """Static helpers for the LongCite prompt/citation workflow.

    Pipeline: get_prompt() splits a context into <C{i}>-tagged chunks and
    renders the instruction prompt; after the LLM answers, postprocess()
    parses the <statement>...<cite>[s-e]</cite></statement> markup and, via
    get_citations(), maps each cited chunk range back onto the original
    context. truncate_from_middle() keeps a prompt within a token budget.
    """

    @staticmethod
    def text_split_by_punctuation(original_text, return_dict=False):
        """Split text into sentence-like chunks.

        Tokenizes with NLTK's PunktSentenceTokenizer, then further splits on
        the CJK terminators in ``punctuations`` and glues each terminator
        back onto the sentence it ends. If everything collapses into a
        single chunk, falls back to splitting on blank-line boundaries.

        return_dict=False -> list of stripped chunk strings.
        return_dict=True  -> list of {'c_idx', 'content', 'start_idx',
        'end_idx'} dicts locating each chunk inside original_text.
        """
        # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text) # separate period without space
        text = original_text
        custom_sent_tokenizer = PunktSentenceTokenizer(text)
        punctuations = r"([。;!?])" # For Chinese support

        separated = custom_sent_tokenizer.tokenize(text)
        # re.split with a capturing group keeps the punctuation marks as
        # their own list items.
        separated = sum([re.split(punctuations, s) for s in separated], [])
        # Put the punctuations back to the sentence
        for i in range(1, len(separated)):
            if re.match(punctuations, separated[i]):
                separated[i-1] += separated[i]
                separated[i] = ''

        separated = [s for s in separated if s != ""]
        if len(separated) == 1:
            # No sentence boundaries were found; fall back to paragraphs.
            separated = original_text.split('\n\n')
        separated = [s.strip() for s in separated if s.strip() != ""]
        if not return_dict:
            return separated
        else:
            # Re-locate each (stripped) chunk inside the original text so
            # the caller gets exact character offsets; pos makes the search
            # monotonic so repeated chunks resolve in order.
            pos = 0
            res = []
            for i, sent in enumerate(separated):
                st = original_text.find(sent, pos)
                assert st != -1, sent
                ed = st + len(sent)
                res.append(
                    {
                        'c_idx': i,
                        'content': sent,
                        'start_idx': st,
                        'end_idx': ed,
                    }
                )
                pos = ed
            return res

    @staticmethod
    def get_prompt(context, question):
        """Build the chunk-tagged LongCite prompt.

        Returns (prompt, sents, splited_context) where splited_context is
        the context with a <C{i}> marker in front of every chunk, and sents
        is a list of {'content', 'start', 'end', 'c_idx'} dicts. Each
        chunk's end is extended to the next chunk's start (or the end of
        the context), so inter-chunk whitespace stays attached to the
        preceding chunk and the chunks exactly tile the context.
        """
        sents = LongCiteModel.text_split_by_punctuation(context, return_dict=True)
        splited_context = ""
        for i, s in enumerate(sents):
            st, ed = s['start_idx'], s['end_idx']
            assert s['content'] == context[st:ed], s
            # Extend this chunk up to the start of the next one (or to the
            # end of the context for the last chunk).
            ed = sents[i+1]['start_idx'] if i < len(sents)-1 else len(context)
            sents[i] = {
                'content': context[st:ed],
                'start': st,
                'end': ed,
                'c_idx': s['c_idx'],
            }
            splited_context += f"<C{i}>"+context[st:ed]
        prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C_{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
        return prompt, sents, splited_context

    @staticmethod
    def get_citations(statement, sents):
        """Extract citation ranges from one <statement> body.

        Pulls every [s-e] span out of the statement's <cite>...</cite>
        tags, strips those tags from the statement text, clamps each span
        to the valid chunk indices, and merges spans that are directly
        adjacent (next start == previous end + 1).

        Returns (statement_without_cites, merged_citations[:3]); each
        citation dict carries start/end sentence indices, start/end char
        offsets into the original context, and the cited text itself.
        """
        c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
        spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
        statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
        merged_citations = []
        for i, s in enumerate(spans):
            try:
                st, ed = [int(x) for x in s.split('-')]
                if st > len(sents) - 1 or ed < st:
                    continue  # out-of-range or inverted span: drop it
                st, ed = max(0, st), min(ed, len(sents)-1)
                assert st <= ed, str(c_texts) + '\t' + str(len(sents))
                if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
                    # Contiguous with the previous citation: extend it in
                    # place instead of appending a new range.
                    merged_citations[-1].update({
                        "end_sentence_idx": ed,
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed+1]]),
                    })
                else:
                    merged_citations.append({
                        "start_sentence_idx": st,
                        "end_sentence_idx": ed,
                        "start_char_idx": sents[st]['start'],
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[st:ed+1]]),
                    })
            except:
                # Dump the offending citation text for debugging, then
                # re-raise the original exception.
                print(c_texts, len(sents), statement)
                raise
        return statement, merged_citations[:3]  # keep at most 3 citations

    @staticmethod
    def postprocess(answer, sents, splited_context):
        """Parse a raw model answer into statements with resolved citations.

        Scans the answer alternating between text outside <statement> tags
        and text inside them. Outside text longer than 5 chars (after
        stripping) is kept as an uncited statement; inside text has its
        <cite> spans resolved through get_citations(). A normalized answer
        string with canonical <statement>/<cite> markup is rebuilt along
        the way.

        Returns a dict with the normalized 'answer', the statements that
        carry citations, the chunk-tagged 'splited_context', and
        'all_statements' (entries with citation=None are short/filler
        spans that were passed through verbatim).
        """
        res = []
        pos = 0
        new_answer = ""
        while True:
            st = answer.find("<statement>", pos)
            if st == -1:
                # No more tags: treat the remaining tail as outside-text.
                st = len(answer)
            ed = answer.find("</statement>", st)
            statement = answer[pos:st]  # text before the next <statement>
            if len(statement.strip()) > 5:
                res.append({
                    "statement": statement,
                    "citation": []
                })
                new_answer += f"<statement>{statement}<cite></cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement

            # NOTE(review): if a <statement> tag is never closed, the loop
            # exits here and everything after that tag is dropped.
            if ed == -1:
                break

            statement = answer[st+len("<statement>"):ed]
            if len(statement.strip()) > 0:
                statement, citations = LongCiteModel.get_citations(statement, sents)
                res.append({
                    "statement": statement,
                    "citation": citations
                })
                c_str = ''.join(['[{}-{}]'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
                new_answer += f"<statement>{statement}<cite>{c_str}</cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement
            pos = ed + len("</statement>")
        return {
            "answer": new_answer.strip(),
            "statements_with_citations": [x for x in res if x['citation'] is not None],
            "splited_context": splited_context.strip(),
            "all_statements": res,
        }

    @staticmethod
    def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
        """Token-truncate *prompt* by cutting out its middle.

        With max_input_length=None the prompt is returned unchanged.
        Otherwise a tokenizer is required (asserted); when the encoded
        prompt exceeds the budget, the first and last max_input_length//2
        tokens are kept and re-decoded, dropping the middle so both the
        head and tail of the prompt survive.
        """
        if max_input_length is None:
            return prompt
        else:
            assert tokenizer is not None
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokenized_prompt) > max_input_length:
                half = int(max_input_length/2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            return prompt
167
+
168
+
169
+
170
+
171
if __name__ == "__main__":
    # Demo driver: plug in your own document and question.
    max_input_length = 4096

    context = '''
your context
'''
    query = "your user question here"

    # Build the chunk-tagged prompt and show it.
    rendered_prompt, chunk_spans, tagged_context = LongCiteModel.get_prompt(context, query)
    print('Prompt:', rendered_prompt)
    # add the Llama 3 tags to the prompt

    output = "..." # what the llm returned
    result = LongCiteModel.postprocess(output, chunk_spans, tagged_context)
183
+
184
+
185
+
186
+
187
+ ```