tiendung committed on
Commit
cec26ce
1 Parent(s): 675b3d2
Files changed (3)
  1. llm.py +261 -0
  2. text_utils.py +79 -0
  3. utils.py +44 -0
llm.py ADDED
@@ -0,0 +1,261 @@
+ #!/usr/bin/env python3
+ import utils; from utils import *
+ import os, sys, lzma, json, pprint, time, subprocess
+
+ thinker = os.getenv("thinker", "gemini")
+ TEMPERATURE = float(os.getenv("temperature", 0.1)) # 0.0 conservative (good for coding and correct syntax)
+
+ LLM_HOST = "gemini"
+ TKNZ_RATIO = 1
+
+ GEMINI_MODEL = 'gemini-1.5-pro-002'
+ FLASH_MODEL = 'gemini-1.5-flash-002'
+
+ # https://github.com/google-gemini/cookbook/blob/main/quickstarts/Prompting.ipynb
+ # https://github.com/google-gemini/cookbook/blob/main/quickstarts/Streaming.ipynb
+ import google.generativeai as genai # pip install -U -q google-generativeai
+ llm_log_filename = f"{location__}/data/llm.log"
+
+
+ genai.configure(api_key=os.getenv("GEMINI_FLASH_API_KEY"))
+
+ GEMINI_CLIENT = genai.GenerativeModel(GEMINI_MODEL, \
+     generation_config=genai.GenerationConfig(
+         max_output_tokens=1024*4,
+         temperature=TEMPERATURE
+ ))
+
+ def chat(prompt, history=[], use_cache=False, stream=False):
+     if stream: return GEMINI_CLIENT.generate_content(prompt, stream=True)
+
+     messages = history + [{"role": "user", "content": prompt}] # fake history
+     with open(llm_log_filename,"at") as f: f.write(f"\n- - - [ {GEMINI_MODEL} ] - - -\n\nPROMPT:\n{prompt}\n")
+
+     try:
+         res = GEMINI_CLIENT.generate_content(prompt, request_options = { "timeout": 6000 })
+         with open(llm_log_filename,"at") as f: f.write(f"\nRESPONSE:\n{res}\n"); f.write(f"\nCONTENT:\n{res.text}\n")
+         messages += [{"role": "assistant", "content": res.text}]
+         return messages
+
+     except Exception as e:
+         with open(llm_log_filename,"at") as f: f.write(f"\nEXCEPTION:\n{e}\n")
+         print(f"\nEXCEPTION:\n{e}\n"); raise e
+
+
+ FLASH_CLIENT = genai.GenerativeModel(FLASH_MODEL, \
+     generation_config=genai.GenerationConfig(
+         max_output_tokens=1024*8,
+         temperature=TEMPERATURE
+ ))
+
+ # def flash_chat(prompt, history=[], use_cache=False, stream=False):
+ #     res = FLASH_CLIENT.generate_content(prompt)
+ #     return [{"role": "assistant", "content": res.text}]
+ flash_chat = chat
+
+ def who_are_you():
+     print(f"{RED}{LLM_HOST}{RESET} " * 2)
+
+
+ if thinker == "gemini": # gemini pro
+     CTXLEN = 1024*64 # gemini handles this easily; 128k or even 1m ctxlen is fine
+     thinker_chat = chat
+
+ elif thinker in "70b|405b":
+     cache_filename = f"{location__}/data/thinker.jsonl.xz"
+     lock_filename = f"{location__}/data/thinker.lock"
+     log_filename = f"{location__}/data/thinker.log"
+
+     ## Load thinker_cache
+     lines = [] if not os.path.exists(cache_filename) else \
+         [ line for line in lzma.open(cache_filename,"rt") ]
+     assert len(lines) % 2 == 0
+     thinker_cache = {}; i = 0
+     while i < len(lines): # each line ends with \n, so [:-1] drops it
+         thinker_cache[lines[i][:-1]] = json.loads(lines[i+1])
+         i += 2
+     lines = None # Done loading
+
+     # https://docs.together.ai/docs/chat-models#hosted-models
+     # each entry packs "model_name ctxlen max_tokens tknz_ratio"
+     model = {
+         "405b": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo 8k 3k 1.2", # $5.00 / 1m tokens(*)
+         "70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo 128k 4k 1.2", # $0.88 / 1m tokens(*)
+     }[thinker]
+
+     model, CTXLEN, MAX_TOKENS, TKNZ_RATIO = model.strip().split()
+     LLM_HOST = model
+
+     MAX_TOKENS = int(MAX_TOKENS[:-1])*1024
+     TKNZ_RATIO = float(TKNZ_RATIO)
+
+     CTXLEN = int(CTXLEN[:-1])
+     if CTXLEN > 32: CTXLEN = 32 # max 32k ctxlen
+     CTXLEN = CTXLEN*1024 - MAX_TOKENS
+     # print(model, CTXLEN, MAX_TOKENS, TKNZ_RATIO); input(); # DEBUG
+
+     from together import Together
+     together_client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))
+     ###
+     stops = ["<|eot_id|>","<|eom_id|>","</answer>","</output>"]
+     def thinker_chat(prompt, history=[], stream=False, use_cache=True, testing=False):
+         if stream:
+             with open(log_filename,"at") as f: f.write(f"\n- - - [ {LLM_HOST} ] - - -\n\nPROMPT:\n{prompt}\n")
+             return together_client.chat.completions.create(
+                 model=model,
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=MAX_TOKENS,
+                 temperature=TEMPERATURE,
+                 top_p=0.7, top_k=50,
+                 repetition_penalty=1.2, stop=stops,
+                 stream=True
+             )
+
+         messages = history + [{"role": "user", "content": prompt}]
+         messages_jsonl = json.dumps(messages, ensure_ascii=False)
+         cache_found = (messages_jsonl in thinker_cache)
+
+         if use_cache and cache_found:
+             print(f"{YELLOW}<<< cached content >>>{RESET}")
+             content = thinker_cache[messages_jsonl]
+
+         elif testing:
+             print(f"{RED}<<< testing content >>>{RESET}")
+             content = "testing testing"
+
+         else:
+             print(f"{GREEN}<<< fresh content >>>{RESET}")
+             with open(log_filename,"at") as f: f.write(f"\n- - - [ {LLM_HOST} ] - - -\n\nPROMPT:\n{prompt}\n")
+             try:
+                 response = Together(api_key=os.environ.get('TOGETHER_API_KEY')).chat.completions.create(
+                     model=model,
+                     messages=messages,
+                     max_tokens=MAX_TOKENS,
+                     temperature=TEMPERATURE,
+                     top_p=0.7, top_k=50,
+                     repetition_penalty=1.2, stop=stops,
+                     logprobs=1, stream=False
+                 )
+             except Exception as e:
+                 with open(log_filename,"at") as f: f.write(f"\nEXCEPTION:\n{e}\n")
+                 print(f"\nEXCEPTION:\n{e}\n"); raise e
+
+             content = response.choices[0].message.content
+             with open(log_filename,"at") as f:
+                 f.write(f"\nRESPONSE:\n{response}\n")
+                 f.write(f"\nCONTENT:\n{content}\n")
+
+             thinker_cache[messages_jsonl] = content # update new generated content
+
+             waits = 5
+             while waits > 0 and os.path.exists(lock_filename): # someone else is writing, wait
+                 waits -= 1
+                 time.sleep(0.2)
+
+             if waits == 0:
+                 assert False, f"Locked for more than 1 second; delete {lock_filename} if this error keeps repeating"
+
+             subprocess.run(f"touch {lock_filename}", shell=True) # lock
+             with lzma.open(cache_filename,"at") as f: # write
+                 f.write(f"{messages_jsonl}\n{json.dumps(content, ensure_ascii=False)}\n")
+             subprocess.run(f"rm {lock_filename}", shell=True) # unlock
+
+         messages += [{"role": "assistant", "content": content}]
+         return messages
+
+
+ elif thinker in "gemma2:27b|commandr:35b|llama3.1:70b":
+     #################
+     ## Ollama connect
+     import subprocess, ollama # pip install ollama
+     try: ollama.list()
+     except: subprocess.run('nohup ssh -N -L 11434:localhost:11434 -p 22021 [email protected] &', shell=True)
+     subprocess.run('nohup ssh -N -L 9999:localhost:11434 -p 17340 [email protected] &', shell=True)
+     #################
+     OLLAMA_CLIENT = ollama.Client(host='http://localhost:11434')
+     machine = "RTX-4090-24G"
+
+     ## ~30b models
+     if thinker in "gemma2:27b": OLLAMA_MODEL = "gemma2:27b-instruct-q5_K_M" ; CTXLEN = 512*14 # fit 24G
+     elif thinker in "commandr:35b": OLLAMA_MODEL = "command-r:35b-08-2024-q4_K_M" ; CTXLEN = 512*18 # fit 24G
+     else: OLLAMA_MODEL = "not found"
+
+     try: connect_to_4090 = OLLAMA_MODEL in str(ollama.list())
+     except: connect_to_4090 = False
+
+     if not connect_to_4090: # switch to A100
+         OLLAMA_CLIENT = ollama.Client(host='http://localhost:9999')
+         machine = "A100-PCIE-40GB"
+         ## ~30b to ~70b models
+         if thinker in "gemma2:27b": OLLAMA_MODEL = "gemma2:27b-instruct-q8_0" ; CTXLEN = 1024*24
+         elif thinker in "commandr:35b": OLLAMA_MODEL = "command-r:35b-08-2024-q8_0" ; CTXLEN = 1024*32
+         elif thinker in "llama3.1:70b": OLLAMA_MODEL = "llama3.1:70b-instruct-q3_K_M" ; CTXLEN = 1024*12 # fit 40G
+     LLM_HOST = f"{machine}__{OLLAMA_MODEL}"
+
+     def thinker_chat(prompt, history=[], stream=False, use_cache=False):
+         if stream:
+             with open(llm_log_filename,"at") as f: f.write(f"\n- - - [ {LLM_HOST} ] - - -\n\nPROMPT:\n{prompt}\n")
+             return OLLAMA_CLIENT.chat(model=OLLAMA_MODEL, messages=[{"role": "user", "content": prompt}], \
+                 stream=True, options={'num_ctx': CTXLEN, 'temperature': TEMPERATURE})
+
+         messages = history + [{"role": "user", "content": prompt}]
+         with open(llm_log_filename,"at") as f: f.write(f"\n- - - [ {LLM_HOST} ] - - -\n\nPROMPT:\n{prompt}\n")
+         res = OLLAMA_CLIENT.chat(model=OLLAMA_MODEL, messages=messages, options={'temperature': TEMPERATURE})
+         content = res["message"]["content"]
+         with open(llm_log_filename,"at") as f: f.write(f"\nCONTENT:\n{content}\n")
+         messages += [{"role": "assistant", "content": content}]
+         return messages
+
+     ## To make it a 100% local LLM, normal chat can also use the thinker
+     # chat = thinker_chat
+
+ LLM_HOST += f"__{round(CTXLEN/1024)}k_ctxlen"
+ who_are_you()
+
+
+
+ from prompts import summary_template
+ from prompts import contextual_template, clean_view_template
+
+ USE_CACHE = os.getenv("cache", "1") == "1"
+
+
+ def extract_keyphrases_figures_summary(text):
+     if len(text) < 80: return ""
+
+     prompt = summary_template.format(text = text)
+     print(f"{GREEN}{text}{RESET}")
+
+     utils.reset_timer(timer = "extract_keyphrases_figures_summary")
+     res = chat(prompt, use_cache = USE_CACHE)
+     utils.measure_time("", timer = "extract_keyphrases_figures_summary")
+
+     raw = res[-1]["content"]
+     print(f"{MAGENTA}{raw}{RESET}")
+
+     return raw
+
+
+ def gen_contextual(document, chunk):
+     prompt = contextual_template.format(document = document, chunk = chunk)
+     res = thinker_chat(prompt, use_cache = USE_CACHE)
+     contextual = res[-1]["content"].strip()
+     return contextual
+
+
+ def gen_clean_view(document):
+     prompt = clean_view_template.format(document = document)
+     res = chat(prompt, use_cache = USE_CACHE)
+     ret = res[-1]["content"].strip()
+     return ret
+
+
+ if __name__ == "__main__":
+
+     try: filename = sys.argv[1]
+     except: filename = None
+     if filename: q = open(filename, "rt").read()
+     else: q = "What's your name? Who created you?"
+
+     utils.reset_timer(); res = thinker_chat(q, use_cache=False)
+     utils.measure_time(LLM_HOST + " ")
+     print(f"{CYAN}{q}{RESET}", end="\n\n"); print(res[-1]["content"])
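A minimal usage sketch for llm.py (illustrative only, not part of the commit): it assumes GEMINI_FLASH_API_KEY is set, the default thinker=gemini branch, and that the prompts module imported near the end of the file is importable.

import llm

# blocking call: returns the message list with the assistant reply appended
msgs = llm.chat("Give one sentence about contextual retrieval.")
print(msgs[-1]["content"])

# streaming call: returns the raw Gemini streaming iterator instead of a message list
for chunk in llm.chat("Count from 1 to 5.", stream=True):
    print(chunk.text, end="", flush=True)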
text_utils.py ADDED
@@ -0,0 +1,79 @@
+ import re, os, sys
+ from utils import *
+
+ def get_paragraphs(text, cutoff=10):
+     return [ x.strip() for x in re.split(r'\n+', text, flags=re.MULTILINE) if len(x.strip()) > cutoff ]
+
+ def get_para_sentences(text, cutoff=60):
+     para_sents = []
+     for para in get_paragraphs(text):
+         sents = []; sent = ""
+         chunks = re.split(r'\.+', para); n = len(chunks)
+         for i in range(0, n):
+             sent += chunks[i]
+             if i < n - 1: sent += "."
+             if len(sent) > cutoff:
+                 sents.append(sent)
+                 sent = ""
+         if len(sent) > 0: sents.append(sent)
+         # print(sents); input()
+         para_sents.append(sents)
+     return para_sents
+
+ def get_idx_from_marked_chunk(marked_chunk):
+     return int(re.match(r'<C\s*(\d+)>', marked_chunk)[1])
+ import random; idx = random.randint(0, 99999)
+ assert get_idx_from_marked_chunk(f"<C {idx}> ha ha") == idx
+
+
+ def add_chunk_markers(text, lookup_idx = None, para = True):
+     if para: para_chunks = get_paragraphs(text)
+     else: para_chunks = get_para_sentences(text)
+
+     marked_text = ""; chunk_idx = 0
+     for chunks in para_chunks:
+         if isinstance(chunks, str): chunks = [chunks]
+         for idx, chunk in enumerate(chunks):
+             marked_chunk = f"<C {chunk_idx}>{chunk.strip()}"
+
+             chunks[idx] = marked_chunk
+             if lookup_idx == chunk_idx: print(marked_chunk); sys.exit() # assert False, f"Found {lookup_idx}"
+
+             marked_text += f"{marked_chunk}\n"
+             chunk_idx += 1
+         marked_text += "\n"
+     return marked_text.strip(), para_chunks
+
+
+ alphabet = '[0-9a-zaăâáắấàằầảẳẩãẵẫạặậđeêéếèềẻểẽễẹệiíìỉĩịoôơóốớòồờỏổởõỗỡọộợuưúứùừủửũữụựyýỳỷỹỵ]' # ASCII digits/letters plus Vietnamese lowercase letters; uppercase is covered by re.IGNORECASE
+ word = re.compile(f'{alphabet}+', re.IGNORECASE)
+ ###
+ def hilite(query, source, hilite_color=YELLOW, source_color=GREY, query_color=None):
+     for keyword in set(re.findall(word, query)):
+         keyword = re.escape(keyword)
+         re_keyword = re.compile(rf"(\b{keyword}\b)", flags=re.IGNORECASE | re.MULTILINE)
+         if re_keyword.search(source):
+             source = re.sub(re_keyword, rf'{hilite_color}\1{source_color}', source)
+             if query_color is not None:
+                 query = re.sub(re_keyword, rf'{hilite_color}\1{query_color}', query)
+     return source, query
+
+
+ def pretty_num(x):
+     return round(x*100)/100
+
+ def count_words(x):
+     assert isinstance(x, str), f"input is not a string: {x}"
+     return len(x.split())
+
+ def extract_(text, tag):
+     raw = text.split(f"</{tag}>")[0].split(f"<{tag}>")[-1]
+     if tag == "summary": return raw.strip()
+     splits = re.split(r'[\n,]+', raw)
+     splits = [ re.sub(r'^\s*-\s*', '', s).strip() for s in splits ]
+     splits = [ s for s in splits if len(s) > 0 ]
+     return splits
+
+ def extract_xmls(text, tags):
+     if text is None: return None
+     return { tag: extract_(text, tag) for tag in tags }
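A short sketch of how the chunk-marking and tag-extraction helpers in text_utils.py fit together (illustrative only, not part of the commit; the sample text, tags, and values are made up):

from text_utils import add_chunk_markers, get_idx_from_marked_chunk, extract_xmls

doc = "First paragraph about retrieval.\n\nSecond paragraph about reranking."
marked_text, para_chunks = add_chunk_markers(doc)
print(marked_text)  # "<C 0>First paragraph ...\n\n<C 1>Second paragraph ..."
print(get_idx_from_marked_chunk("<C 1>Second paragraph about reranking."))  # 1

# extract_xmls pulls <tag>...</tag> spans out of an LLM response;
# "summary" stays a string, other tags are split into lists
response = "<summary>Both paragraphs are about search quality.</summary> <keyphrases>retrieval, reranking</keyphrases>"
print(extract_xmls(response, ["summary", "keyphrases"]))
# {'summary': 'Both paragraphs are about search quality.', 'keyphrases': ['retrieval', 'reranking']}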
utils.py ADDED
@@ -0,0 +1,44 @@
+ import time, os
+ location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+
+ ## Commonly used colors
+ BLACK = '\033[30m'; WHITE = '\033[97m'
+ RED = '\033[91m'; YELLOW = '\033[33m'
+ GREEN = '\033[32m'; CYAN = '\033[36m'
+ BLUE = '\033[94m'; GREY = '\033[37m'
+ MAGENTA = '\033[95m'; RESET = '\033[0m'
+
+ # Dark # Bright # Dark background # Bright background
+ DK = '\033[30m'; BK = '\033[90m'; GDK = '\033[40m'; GBK = '\033[100m'; # BLACK
+ DR = '\033[31m'; BR = '\033[91m'; GDR = '\033[41m'; GBR = '\033[101m'; # RED
+ DG = '\033[32m'; BG = '\033[92m'; GDG = '\033[42m'; GBG = '\033[102m'; # GREEN
+ DY = '\033[33m'; BY = '\033[93m'; GDY = '\033[43m'; GBY = '\033[103m'; # YELLOW
+ DB = '\033[34m'; BB = '\033[94m'; GDB = '\033[44m'; GBB = '\033[104m'; # BLUE
+ DM = '\033[35m'; BM = '\033[95m'; GDM = '\033[45m'; GBM = '\033[105m'; # MAGENTA (purplish pink)
+ DC = '\033[36m'; BC = '\033[96m'; GDC = '\033[46m'; GBC = '\033[106m'; # CYAN
+ DW = '\033[37m'; BW = '\033[97m'; GDW = '\033[47m'; GBW = '\033[107m'; # WHITE
+
+ def pretty_num(x): return round(x*100)/100
+
+ TIMER_STARTED_AT = { "default": time.time() }
+ def reset_timer(timer="default"):
+     global TIMER_STARTED_AT
+     TIMER_STARTED_AT[timer] = time.time()
+
+ def measure_time(message="", timer="default", color=YELLOW):
+     total = time.time() - TIMER_STARTED_AT[timer]
+     total = round(total * 100) / 100
+
+     message = message.strip()
+     if len(message) > 0:
+         message = " " + message
+
+     print(f"{color}{timer}:{message} {total} seconds{RESET}")
+
+ count_words = lambda x: len(x.split())
+
+ if __name__ == "__main__":
+     reset_timer(timer="my timer")
+     s = "chào cả nhà, cả nhà khỏe không ạ?"
+     print(f"{RED}{s}{RESET} has {CYAN}{count_words(s)}{RESET} words")
+     measure_time("total run time", timer="my timer")
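And a tiny sketch of consuming utils.py from another script (illustrative only, not part of the commit; the file name is made up):

# timing_demo.py -- assumes utils.py is importable
from utils import GREEN, RED, RESET, location__, reset_timer, measure_time

print(f"log files would live under {GREEN}{location__}/data{RESET}")

reset_timer(timer="demo")
total = sum(i * i for i in range(1_000_000))
print(f"{RED}sum of squares:{RESET} {total}")
measure_time("sum of squares", timer="demo")  # prints "demo: sum of squares <elapsed> seconds"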