Update README.md
Browse files
README.md
CHANGED
@@ -5,4 +5,183 @@ base_model:
|
|
5 |
---
|
6 |
GGUF version of LongCite. You need to add the following tokens as stop tokens: `[128000, 128007, 128009]` or `["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"]`
|
7 |
|
8 |
-
By default — and it seems to be working so far — the EOS token is 128007 (end_header_id). This works for both citation mode and naive question-answer mode.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
---
|
6 |
GGUF version of LongCite. You need to add the following tokens as stop tokens: `[128000, 128007, 128009]` or `["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"]`
|
7 |
|
8 |
+
By default — and it seems to be working so far — the EOS token is 128007 (end_header_id). This works for both citation mode and naive question-answer mode.
|
9 |
+
|
10 |
+
Example code
|
11 |
+
```python
|
12 |
+
|
13 |
+
from nltk.tokenize import PunktSentenceTokenizer
|
14 |
+
import re
|
15 |
+
|
16 |
+
class LongCiteModel:
    """Helpers for building LongCite prompts and parsing cited answers.

    The context is split into sentence chunks tagged ``<C0>``, ``<C1>``, ...
    and the model is expected to answer with spans of the form
    ``<statement>...<cite>[s-e][s-e]...</cite></statement>``.  The static
    methods here build that prompt and parse the citations back out of the
    raw model output.
    """

    @staticmethod
    def text_split_by_punctuation(original_text, return_dict=False):
        """Split *original_text* into sentences.

        Uses NLTK's PunktSentenceTokenizer, then additionally splits on
        CJK sentence-ending punctuation (Chinese support).  If only a
        single sentence is found, falls back to splitting on blank lines.

        Returns a list of sentence strings, or — when *return_dict* is
        true — a list of dicts with keys 'c_idx', 'content', 'start_idx',
        'end_idx' (character offsets into *original_text*).
        """
        # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text) # separate period without space
        text = original_text
        custom_sent_tokenizer = PunktSentenceTokenizer(text)
        punctuations = r"([。;!?])"  # For Chinese support

        separated = custom_sent_tokenizer.tokenize(text)
        separated = sum([re.split(punctuations, s) for s in separated], [])
        # re.split leaves the captured punctuation as its own element;
        # glue each punctuation mark back onto the preceding sentence.
        for i in range(1, len(separated)):
            if re.match(punctuations, separated[i]):
                separated[i-1] += separated[i]
                separated[i] = ''

        separated = [s for s in separated if s != ""]
        if len(separated) == 1:
            # Tokenizer saw one big sentence: fall back to paragraph splits.
            separated = original_text.split('\n\n')
        separated = [s.strip() for s in separated if s.strip() != ""]
        if not return_dict:
            return separated

        # Locate each sentence in the original text so callers also get
        # character offsets.  Searching from `pos` keeps duplicate
        # sentences mapped to distinct positions.
        pos = 0
        res = []
        for i, sent in enumerate(separated):
            st = original_text.find(sent, pos)
            assert st != -1, sent
            ed = st + len(sent)
            res.append(
                {
                    'c_idx': i,
                    'content': sent,
                    'start_idx': st,
                    'end_idx': ed,
                }
            )
            pos = ed
        return res

    @staticmethod
    def get_prompt(context, question):
        """Build the chunk-tagged LongCite prompt.

        Returns ``(prompt, sents, splited_context)`` where *sents* is the
        list of chunk dicts ('content', 'start', 'end', 'c_idx') and
        *splited_context* is the context with ``<Ci>`` markers inserted.
        """
        sents = LongCiteModel.text_split_by_punctuation(context, return_dict=True)
        splited_context = ""
        for i, s in enumerate(sents):
            st, ed = s['start_idx'], s['end_idx']
            assert s['content'] == context[st:ed], s
            # Extend each chunk up to the start of the next one so the
            # concatenation reproduces the full context, including the
            # whitespace the sentence splitter stripped.
            ed = sents[i+1]['start_idx'] if i < len(sents)-1 else len(context)
            sents[i] = {
                'content': context[st:ed],
                'start': st,
                'end': ed,
                'c_idx': s['c_idx'],
            }
            splited_context += f"<C{i}>"+context[st:ed]
        # NOTE(review): the template contains "<C_{e1}>" (stray underscore),
        # presumably matching the model's training prompt — left untouched.
        prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C_{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
        return prompt, sents, splited_context

    @staticmethod
    def get_citations(statement, sents):
        """Extract citation spans from one ``<statement>`` body.

        Parses every ``<cite>[s-e]...</cite>`` tag out of *statement*,
        merges consecutive sentence ranges, and returns
        ``(statement_without_cite_tags, merged_citations[:3])``.
        Out-of-range or inverted spans are silently dropped.
        """
        c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
        spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
        statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
        merged_citations = []
        for i, s in enumerate(spans):
            try:
                st, ed = [int(x) for x in s.split('-')]
                if st > len(sents) - 1 or ed < st:
                    continue  # span outside the document or inverted: skip
                st, ed = max(0, st), min(ed, len(sents)-1)
                assert st <= ed, str(c_texts) + '\t' + str(len(sents))
                if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
                    # Contiguous with the previous citation: extend it in place.
                    merged_citations[-1].update({
                        "end_sentence_idx": ed,
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed+1]]),
                    })
                else:
                    merged_citations.append({
                        "start_sentence_idx": st,
                        "end_sentence_idx": ed,
                        "start_char_idx": sents[st]['start'],
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[st:ed+1]]),
                    })
            except Exception:  # was a bare `except:`; keep the debug print, then re-raise
                print(c_texts, len(sents), statement)
                raise
        return statement, merged_citations[:3]

    @staticmethod
    def postprocess(answer, sents, splited_context):
        """Parse the raw model *answer* into structured statements.

        Walks the answer, treating ``<statement>...</statement>`` regions
        as citable statements (resolved via :meth:`get_citations`) and the
        text between them as plain statements (citation ``[]`` when longer
        than 5 chars after stripping, ``None`` otherwise).

        Returns a dict with 'answer' (normalized answer text),
        'statements_with_citations', 'splited_context', 'all_statements'.
        """
        res = []
        pos = 0
        new_answer = ""
        while True:
            st = answer.find("<statement>", pos)
            if st == -1:
                st = len(answer)
            ed = answer.find("</statement>", st)
            # Text between the previous statement and this one.
            statement = answer[pos:st]
            if len(statement.strip()) > 5:
                res.append({
                    "statement": statement,
                    "citation": []
                })
                new_answer += f"<statement>{statement}<cite></cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement

            if ed == -1:
                break

            # The tagged statement body itself.
            statement = answer[st+len("<statement>"):ed]
            if len(statement.strip()) > 0:
                statement, citations = LongCiteModel.get_citations(statement, sents)
                res.append({
                    "statement": statement,
                    "citation": citations
                })
                c_str = ''.join(['[{}-{}]'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
                new_answer += f"<statement>{statement}<cite>{c_str}</cite></statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement
            pos = ed + len("</statement>")
        return {
            "answer": new_answer.strip(),
            "statements_with_citations": [x for x in res if x['citation'] is not None],
            "splited_context": splited_context.strip(),
            "all_statements": res,
        }

    @staticmethod
    def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
        """Token-truncate *prompt* by dropping its middle.

        When the tokenized prompt exceeds *max_input_length*, keeps the
        first and last ``max_input_length // 2`` tokens and drops the
        middle.  With ``max_input_length=None`` the prompt is returned
        unchanged and *tokenizer* may be omitted.
        """
        if max_input_length is None:
            return prompt
        assert tokenizer is not None
        tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
        if len(tokenized_prompt) > max_input_length:
            half = int(max_input_length/2)
            prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        return prompt
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
if __name__ == "__main__":
    # Minimal end-to-end example: build the prompt, send it to the model
    # (not shown), then parse the citations out of the raw output.
    context = '''
your context
'''
    query = "your user question here"

    prompt, sents, splited_context = LongCiteModel.get_prompt(context, query)
    print(f'Prompt: {prompt}')

    # add the Llama 3 tags to the prompt
    max_input_length = 4096

    output = "..."  # what the llm returned
    result = LongCiteModel.postprocess(output, sents, splited_context)
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
|
187 |
+
```
|