import re
import json
import time
import copy

import elasticsearch
from elastic_transport import ConnectionTimeout
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index

from rag.settings import es_logger
from rag import settings
from rag.utils import singleton

es_logger.info("Elasticsearch version: " + str(elasticsearch.__version__))
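

# ESConnection is a thin retry wrapper around the official Elasticsearch client:
# most calls are retried on timeouts and the client is re-created via conn()
# after other transport errors, so callers get plain return values instead of
# transport exceptions.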
@singleton  # re-use one ESConnection per process (helper imported from rag.utils)
class ESConnection:
    def __init__(self):
        self.info = {}
        self.conn()
        self.idxnm = settings.ES.get("index_name", "")
        if not self.es.ping():
            raise Exception("Can't connect to ES cluster")

    def conn(self):
        for _ in range(10):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None,
                    verify_certs=False,
                    timeout=600)
                if self.es:
                    self.info = self.es.info()
                    es_logger.info("Connect to es.")
                    break
            except Exception as e:
                es_logger.error("Fail to connect to es: " + str(e))
                time.sleep(1)

    def version(self):
        # True when the cluster major version is 7 or newer; falls back to an
        # old version number if info() was never fetched successfully.
        v = self.info.get("version", {"number": "5.6"})
        v = v["number"].split(".")[0]
        return int(v) >= 7

    def health(self):
        return dict(self.es.cluster.health())

    def upsert(self, df, idxnm=""):
        res = []
        for d in df:
            id = d["id"]
            del d["id"]
            d = {"doc": d, "doc_as_upsert": "true"}
            T = False
            for _ in range(10):
                try:
                    if not self.version():
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            doc_type="doc",
                            refresh=True,
                            retry_on_conflict=100)
                    else:
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            refresh=True,
                            retry_on_conflict=100)
                    es_logger.info("Successfully upsert: %s" % id)
                    T = True
                    break
                except Exception as e:
                    es_logger.warning("Fail to index: " +
                                      json.dumps(d, ensure_ascii=False) + str(e))
                    if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                        time.sleep(3)
                        continue
                    self.conn()
                    T = False

            if not T:
                res.append(d)
                es_logger.error(
                    "Fail to index: " +
                    re.sub(r"[\r\n]", "", json.dumps(d, ensure_ascii=False)))
                d["id"] = id
                d["_index"] = self.idxnm

        if not res:
            return True
        return False
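
    # The bulk helpers below follow the Elasticsearch bulk API layout: the
    # action list alternates an {"update": {...}} metadata entry with a
    # {"doc": ..., "doc_as_upsert": true} document entry, so each pair
    # performs an upsert for one document.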
    def bulk(self, df, idx_nm=None):
        ids, acts = {}, []
        for d in df:
            id = d["id"] if "id" in d else d["_id"]
            ids[id] = copy.deepcopy(d)
            ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
            if "id" in d:
                del d["id"]
            if "_id" in d:
                del d["_id"]
            acts.append(
                {"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
            acts.append({"doc": d, "doc_as_upsert": "true"})

        res = []
        for _ in range(100):
            try:
                if elasticsearch.__version__[0] < 8:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        body=acts,
                        refresh=False,
                        timeout="600s")
                else:
                    r = self.es.bulk(index=(self.idxnm if not idx_nm else idx_nm),
                                     operations=acts,
                                     refresh=False, timeout="600s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    # no item-level errors reported for the whole batch
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]) +
                                   ":" + str(it["update"]["error"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()

        return res

    def bulk4script(self, df):
        ids, acts = {}, []
        for d in df:
            id = d["id"]
            ids[id] = copy.deepcopy(d["raw"])
            acts.append({"update": {"_id": id, "_index": self.idxnm}})
            acts.append(d["script"])
            es_logger.info("bulk upsert: %s" % id)

        res = []
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s",
                        doc_type="doc")
                else:
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()

        return res

    def rm(self, d):
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        doc_type="doc",
                        refresh=True)
                else:
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        refresh=True,
                        doc_type="_doc")
                es_logger.info("Remove %s" % d["id"])
                return True
            except Exception as e:
                es_logger.warning("Fail to delete: " + str(d) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return True
                self.conn()

        es_logger.error("Fail to delete: " + str(d))
        return False

    def search(self, q, idxnm=None, src=False, timeout="2s"):
        if not isinstance(q, dict):
            q = Search().query(q).to_dict()
        for i in range(3):
            try:
                res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
                                     body=q,
                                     timeout=timeout,
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=src)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES search exception: " + str(e) + "【Q】:" + str(q))
                if str(e).find("Timeout") > 0:
                    continue
                raise e

        es_logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")

    def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
        for i in range(3):
            try:
                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size},
                                        format=format,
                                        request_timeout=timeout)
                return res
            except ConnectionTimeout:
                es_logger.error("Timeout【Q】:" + sql)
                continue
            except Exception as e:
                raise e

        es_logger.error("ES search timeout for 3 times!")
        raise ConnectionTimeout("ES SQL query timed out after 3 retries.")

    def get(self, doc_id, idxnm=None):
        for i in range(3):
            try:
                res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
                                  id=doc_id)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES get exception: " + str(e) + "【Q】:" + doc_id)
                if str(e).find("Timeout") > 0:
                    continue
                raise e

        es_logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")
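
    # updateByQuery turns a plain {field: value} dict into a Painless script of
    # "ctx._source.<field> = params.<field>;" statements, so one update-by-query
    # request can set several fields on every document matching the query.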
    def updateByQuery(self, q, d):
        ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
        scripts = ""
        for k, v in d.items():
            scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
        ubq = ubq.script(source=scripts, params=d)
        ubq = ubq.params(refresh=False)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")

        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()

        return False

    def updateScriptByQuery(self, q, scripts, idxnm=None):
        ubq = UpdateByQuery(
            index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
        ubq = ubq.script(source=scripts)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")

        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()

        return False

    def deleteByQuery(self, query, idxnm=""):
        for i in range(3):
            try:
                r = self.es.delete_by_query(
                    index=idxnm if idxnm else self.idxnm,
                    refresh=True,
                    body=Search().query(query).to_dict())
                return True
            except Exception as e:
                es_logger.error("ES deleteByQuery exception: " +
                                str(e) + "【Q】:" + str(query.to_dict()))
                if str(e).find("NotFoundError") > 0:
                    return True
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue

        return False

    def update(self, id, script, routing=None):
        for i in range(3):
            try:
                if not self.version():
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        doc_type="doc",
                        routing=routing,
                        refresh=False)
                else:
                    r = self.es.update(index=self.idxnm, id=id,
                                       body=json.dumps(script, ensure_ascii=False),
                                       routing=routing, refresh=False)  # , doc_type="_doc")
                return True
            except Exception as e:
                es_logger.error(
                    "ES update exception: " + str(e) + " id:" + str(id) +
                    ", version:" + str(self.version()) +
                    json.dumps(script, ensure_ascii=False))
                if str(e).find("Timeout") > 0:
                    continue

        return False

    def indexExist(self, idxnm):
        s = Index(idxnm if idxnm else self.idxnm, self.es)
        for i in range(3):
            try:
                return s.exists()
            except Exception as e:
                es_logger.error("ES indexExist exception: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue

        return False

    def docExist(self, docid, idxnm=None):
        for i in range(3):
            try:
                return self.es.exists(index=(idxnm if idxnm else self.idxnm),
                                      id=docid)
            except Exception as e:
                es_logger.error("ES Doc Exist: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue

        return False
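
    # createIdx expects a standard index definition with top-level "settings"
    # and "mappings" keys: older clients receive the whole body, while an 8.x
    # client gets the two sections as separate arguments to IndicesClient.create().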
    def createIdx(self, idxnm, mapping):
        try:
            if elasticsearch.__version__[0] < 8:
                return self.es.indices.create(idxnm, body=mapping)
            from elasticsearch.client import IndicesClient
            return IndicesClient(self.es).create(index=idxnm,
                                                 settings=mapping["settings"],
                                                 mappings=mapping["mappings"])
        except Exception as e:
            es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))

    def deleteIdx(self, idxnm):
        try:
            return self.es.indices.delete(index=idxnm, allow_no_indices=True)
        except Exception as e:
            es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))

    def getTotal(self, res):
        if isinstance(res["hits"]["total"], dict):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getDocIds(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def getSource(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    def scrollIter(self, pagesize=100, scroll_time='2m', q={
            "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
        for _ in range(100):
            try:
                page = self.es.search(
                    index=self.idxnm,
                    scroll=scroll_time,
                    size=pagesize,
                    body=q,
                    _source=None
                )
                break
            except Exception as e:
                es_logger.error("ES scrolling fail. " + str(e))
                time.sleep(3)

        sid = page['_scroll_id']
        scroll_size = page['hits']['total']["value"]
        es_logger.info("[TOTAL]%d" % scroll_size)
        # Start scrolling
        while scroll_size > 0:
            yield page["hits"]["hits"]
            for _ in range(100):
                try:
                    page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
                    break
                except Exception as e:
                    es_logger.error("ES scrolling fail. " + str(e))
                    time.sleep(3)
            # Update the scroll ID
            sid = page['_scroll_id']
            # Get the number of results returned in the last scroll
            scroll_size = len(page['hits']['hits'])


ELASTICSEARCH = ESConnection()
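

# A minimal usage sketch (not part of the module's public API): it assumes the
# configured cluster is reachable and settings.ES["index_name"] already exists,
# and it only exercises methods defined above. Run the file directly to try it.
if __name__ == "__main__":
    demo_query = {"query": {"match_all": {}}, "size": 3}
    resp = ELASTICSEARCH.search(demo_query, src=True)
    print("total hits:", ELASTICSEARCH.getTotal(resp))
    for doc in ELASTICSEARCH.getSource(resp):
        print(doc.get("id"), doc.get("_score"))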