import re
import json
import time
import copy

import elasticsearch
from elastic_transport import ConnectionTimeout
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index

from rag.settings import es_logger
from rag import settings
from rag.utils import singleton

es_logger.info("Elasticsearch version: " + str(elasticsearch.__version__))


@singleton
class ESConnection:
    def __init__(self):
        self.info = {}
        self.conn()
        self.idxnm = settings.ES.get("index_name", "")
        if not self.es.ping():
            raise Exception("Can't connect to ES cluster")

    def conn(self):
        # Try to (re)connect up to 10 times, one second apart.
        for _ in range(10):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    basic_auth=(
                        settings.ES["username"],
                        settings.ES["password"]
                    ) if "username" in settings.ES and "password" in settings.ES else None,
                    verify_certs=False,
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    es_logger.info("Connected to ES.")
                    break
            except Exception as e:
                es_logger.error("Fail to connect to es: " + str(e))
                time.sleep(1)

    def version(self):
        # Returns True when the cluster is ES 7 or newer.
        v = self.info.get("version", {"number": "5.6"})
        v = v["number"].split(".")[0]
        return int(v) >= 7

    def health(self):
        return dict(self.es.cluster.health())

    def upsert(self, df, idxnm=""):
        res = []
        for d in df:
            id = d["id"]
            del d["id"]
            d = {"doc": d, "doc_as_upsert": "true"}
            T = False
            for _ in range(10):
                try:
                    if not self.version():
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            doc_type="doc",
                            refresh=True,
                            retry_on_conflict=100)
                    else:
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            refresh=True,
                            retry_on_conflict=100)
                    es_logger.info("Successfully upsert: %s" % id)
                    T = True
                    break
                except Exception as e:
                    es_logger.warning(
                        "Fail to index: " +
                        json.dumps(d, ensure_ascii=False) + str(e))
                    if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                        time.sleep(3)
                        continue
                    self.conn()
                    T = False

            if not T:
                res.append(d)
                es_logger.error(
                    "Fail to index: " +
                    re.sub("[\r\n]", "", json.dumps(d, ensure_ascii=False)))
                d["id"] = id
                d["_index"] = self.idxnm

        # True when every document was upserted successfully.
        return not res

    def bulk(self, df, idx_nm=None):
        ids, acts = {}, []
        for d in df:
            id = d["id"] if "id" in d else d["_id"]
            ids[id] = copy.deepcopy(d)
            ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
            if "id" in d:
                del d["id"]
            if "_id" in d:
                del d["_id"]
            acts.append(
                {"update": {"_id": id, "_index": ids[id]["_index"]},
                 "retry_on_conflict": 100})
            acts.append({"doc": d, "doc_as_upsert": "true"})

        res = []
        for _ in range(100):
            try:
                if elasticsearch.__version__[0] < 8:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        body=acts,
                        refresh=False,
                        timeout="600s")
                else:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        operations=acts,
                        refresh=False,
                        timeout="600s")
                # "errors" is False when every action in the batch succeeded.
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res

                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]) +
                                   ":" + str(it["update"]["error"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()

        return res

    def bulk4script(self, df):
        ids, acts = {}, []
        for d in df:
            id = d["id"]
            ids[id] = copy.deepcopy(d["raw"])
            acts.append({"update": {"_id": id, "_index": self.idxnm}})
            acts.append(d["script"])
            es_logger.info("bulk upsert: %s" % id)

        res = []
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s",
                        doc_type="doc")
                else:
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()

        return res
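    # Note: both bulk helpers above send `acts` as alternating pairs (an
    # {"update": ...} action entry followed by its partial document or script),
    # which is the shape the Bulk API expects for update-style operations.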
    def rm(self, d):
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        doc_type="doc",
                        refresh=True)
                else:
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        refresh=True,
                        doc_type="_doc")
                es_logger.info("Remove %s" % d["id"])
                return True
            except Exception as e:
                es_logger.warning("Fail to delete: " + str(d) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    # Already gone: treat as a successful delete.
                    return True
                self.conn()

        es_logger.error("Fail to delete: " + str(d))
        return False

    def search(self, q, idxnm=None, src=False, timeout="2s"):
        if not isinstance(q, dict):
            q = Search().query(q).to_dict()
        for i in range(3):
            try:
                res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
                                     body=q,
                                     timeout=timeout,
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=src)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES search exception: " + str(e) + "【Q】:" + str(q))
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        es_logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")

    def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
        for i in range(3):
            try:
                res = self.es.sql.query(
                    body={"query": sql, "fetch_size": fetch_size},
                    format=format,
                    request_timeout=timeout)
                return res
            except ConnectionTimeout:
                es_logger.error("Timeout【Q】:" + sql)
                continue
        es_logger.error("ES sql timeout for 3 times!")
        raise ConnectionTimeout()

    def get(self, doc_id, idxnm=None):
        for i in range(3):
            try:
                res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
                                  id=doc_id)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES get exception: " + str(e) + "【Q】:" + doc_id)
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        es_logger.error("ES get timeout for 3 times!")
        raise Exception("ES get timeout.")

    def updateByQuery(self, q, d):
        ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
        # Build a Painless script that copies every field in `d` onto the doc.
        scripts = ""
        for k, v in d.items():
            scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
        ubq = ubq.script(source=scripts, params=d)
        ubq = ubq.params(refresh=False)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()

        return False

    def updateScriptByQuery(self, q, scripts, idxnm=None):
        ubq = UpdateByQuery(
            index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
        ubq = ubq.script(source=scripts)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateScriptByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()

        return False
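    # Note: the update-by-query helpers above run with conflicts="proceed", so
    # documents changed by a concurrent write are skipped rather than failing
    # the whole request; the loop only re-runs on Timeout/Conflict errors.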
    def deleteByQuery(self, query, idxnm=""):
        for i in range(3):
            try:
                r = self.es.delete_by_query(
                    index=idxnm if idxnm else self.idxnm,
                    refresh=True,
                    body=Search().query(query).to_dict())
                return True
            except Exception as e:
                es_logger.error("ES deleteByQuery exception: " +
                                str(e) + "【Q】:" + str(query.to_dict()))
                if str(e).find("NotFoundError") > 0:
                    return True
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def update(self, id, script, routing=None):
        for i in range(3):
            try:
                if not self.version():
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        doc_type="doc",
                        routing=routing,
                        refresh=False)
                else:
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        routing=routing,
                        refresh=False)  # , doc_type="_doc")
                return True
            except Exception as e:
                es_logger.error(
                    "ES update exception: " + str(e) + " id:" + str(id) +
                    ", version:" + str(self.version()) +
                    json.dumps(script, ensure_ascii=False))
                if str(e).find("Timeout") > 0:
                    continue
        return False

    def indexExist(self, idxnm):
        s = Index(idxnm if idxnm else self.idxnm, self.es)
        for i in range(3):
            try:
                return s.exists()
            except Exception as e:
                es_logger.error("ES indexExist exception: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def docExist(self, docid, idxnm=None):
        for i in range(3):
            try:
                return self.es.exists(index=(idxnm if idxnm else self.idxnm),
                                      id=docid)
            except Exception as e:
                es_logger.error("ES docExist exception: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def createIdx(self, idxnm, mapping):
        try:
            if elasticsearch.__version__[0] < 8:
                return self.es.indices.create(idxnm, body=mapping)
            from elasticsearch.client import IndicesClient
            return IndicesClient(self.es).create(
                index=idxnm,
                settings=mapping["settings"],
                mappings=mapping["mappings"])
        except Exception as e:
            es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))

    def deleteIdx(self, idxnm):
        try:
            return self.es.indices.delete(idxnm, allow_no_indices=True)
        except Exception as e:
            es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))

    def getTotal(self, res):
        if isinstance(res["hits"]["total"], dict):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getDocIds(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def getSource(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    def scrollIter(self, pagesize=100, scroll_time='2m', q={
            "query": {"match_all": {}},
            "sort": [{"updated_at": {"order": "desc"}}]}):
        for _ in range(100):
            try:
                page = self.es.search(
                    index=self.idxnm,
                    scroll=scroll_time,
                    size=pagesize,
                    body=q,
                    _source=None
                )
                break
            except Exception as e:
                es_logger.error("ES scrolling fail. " + str(e))
                time.sleep(3)

        sid = page['_scroll_id']
        scroll_size = page['hits']['total']["value"]
        es_logger.info("[TOTAL]%d" % scroll_size)
        # Start scrolling
        while scroll_size > 0:
            yield page["hits"]["hits"]
            for _ in range(100):
                try:
                    page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
                    break
                except Exception as e:
                    es_logger.error("ES scrolling fail. " + str(e))
                    time.sleep(3)
            # Update the scroll ID
            sid = page['_scroll_id']
            # Get the number of results returned by the last scroll
            scroll_size = len(page['hits']['hits'])


ELASTICSEARCH = ESConnection()
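
# Minimal usage sketch (an illustration, not part of the module's API): it
# assumes a reachable cluster configured in rag.settings, and the "content"
# field name below is hypothetical and depends on the actual index mapping.
if __name__ == "__main__":
    from elasticsearch_dsl import Q

    print(ELASTICSEARCH.health())
    res = ELASTICSEARCH.search(Q("match", content="test"), src=True)
    print("total:", ELASTICSEARCH.getTotal(res))
    for hit in ELASTICSEARCH.getSource(res):
        print(hit["id"], hit["_score"])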