|
import re |
|
import os |
|
import hashlib |
|
import torch |
|
import struct |
|
from collections import Counter |
|
from huggingface_hub import hf_hub_download |
|
from sentence_transformers import SentenceTransformer |
|
from .header import get_dir |
|
|
|
# Process-wide "already done" flags. Each one is read at the top of its
# checker and set to True once that checker has completed successfully, so
# the (potentially slow) work runs at most once per process.
# Set by check_env() after all required environment variables validated.
ENV_CHECKED = False
# Set by check_embedding() after the model files were found or downloaded.
EMBEDDING_CHECKED = False
|
|
|
|
|
def check_embedding(repo_id):
    """Ensure the files of a supported embedding model exist locally.

    On the first call in the process, every file required by ``repo_id`` is
    looked up under ``./assets/model/{repo_id}``; any file that is missing is
    fetched with ``hf_hub_download`` into that directory. Once this pass has
    completed, the module-level ``EMBEDDING_CHECKED`` flag is set and later
    calls (for any ``repo_id``) only print the banner and return.

    Args:
        repo_id: Hugging Face repository id of the embedding model.

    Raises:
        ValueError: If ``repo_id`` is not one of the supported models.
            (Previously this case failed with an ``UnboundLocalError``.)
    """
    print("=== check embedding model ===")
    global EMBEDDING_CHECKED
    if EMBEDDING_CHECKED:
        return

    # File set shared by the three BERT-style checkpoints below.
    bert_style_files = [
        "config.json",
        "pytorch_model.bin",
        "tokenizer_config.json",
        "vocab.txt",
    ]
    files_by_repo = {
        "sentence-transformers/all-MiniLM-L6-v2": bert_style_files,
        "BAAI/bge-small-en-v1.5": bert_style_files,
        "BAAI/llm-embedder": bert_style_files,
        "jinaai/jina-embeddings-v3": [
            "model.safetensors",
            "modules.json",
            "tokenizer.json",
            "config_sentence_transformers.json",
            "custom_st.py",
            "special_tokens_map.json",
            "tokenizer_config.json",
            "1_Pooling/config.json",
            "config.json",
        ],
        "Alibaba-NLP/gte-base-en-v1.5": [
            "config.json",
            "model.safetensors",
            "modules.json",
            "tokenizer.json",
            "sentence_bert_config.json",
            "tokenizer_config.json",
            "vocab.txt",
        ],
    }
    if repo_id not in files_by_repo:
        # Fail with an actionable message instead of an unbound-local crash.
        raise ValueError(
            f"unsupported embedding model {repo_id!r}; "
            f"supported models: {sorted(files_by_repo)}"
        )

    local_dir = f"./assets/model/{repo_id}"
    for file_name in files_by_repo[repo_id]:
        if not os.path.exists(os.path.join(local_dir, file_name)):
            print(
                f"file: {file_name} not exist in {local_dir}, try to download from huggingface ..."
            )
            hf_hub_download(
                repo_id=repo_id,
                filename=file_name,
                local_dir=local_dir,
            )
    # Only reached when every file is present or was downloaded without error.
    EMBEDDING_CHECKED = True
|
|
|
|
|
def check_env():
    """Validate the required environment variables once per process.

    The Neo4j and model settings must all be present and non-empty. When
    ``MODEL_TYPE`` is anything other than ``"Local"``, ``MODEL_API_KEY`` is
    required as well. After a successful pass the module-level
    ``ENV_CHECKED`` flag is set and later calls return immediately.

    Raises:
        ValueError: Naming the first variable that is missing or empty.
    """
    global ENV_CHECKED
    if ENV_CHECKED:
        return

    def _require(name):
        # A variable counts as unset when it is absent or an empty string.
        if not os.environ.get(name):
            raise ValueError(f"{name} is not set...")

    for name in (
        "NEO4J_URL",
        "NEO4J_USERNAME",
        "NEO4J_PASSWD",
        "MODEL_NAME",
        "MODEL_TYPE",
        "BASE_URL",
    ):
        _require(name)

    # MODEL_TYPE is guaranteed to exist here: it was validated just above.
    if os.environ["MODEL_TYPE"] != "Local":
        _require("MODEL_API_KEY")

    ENV_CHECKED = True
|
|
|
|
|
class EmbeddingModel:
    """Singleton wrapper around the loaded ``SentenceTransformer`` model.

    The first ``EmbeddingModel(config)`` call loads the model from
    ``./assets/model/{config.DEFAULT.embedding}`` (resolved through
    ``get_dir``) onto CUDA when available, otherwise the CPU, and stores it
    on the instance as ``embedding_model``. Every later call returns that
    same instance and ignores its ``config`` argument.
    """

    # The shared instance; stays None until a model has loaded successfully.
    _instance = None

    def __new__(cls, config):
        """Return the shared instance, loading the model on first use.

        Args:
            config: Object exposing ``DEFAULT.embedding`` (model name) and,
                for jina-embeddings-v3 models, ``DEFAULT.embedding_task``.
        """
        if cls._instance is None:
            model_name = config.DEFAULT.embedding
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # NOTE(review): trust_remote_code=True lets the model directory
            # supply Python code that is executed at load time; only point
            # this at model files you trust.
            model = SentenceTransformer(
                model_name_or_path=get_dir(f"./assets/model/{model_name}"),
                device=device,
                trust_remote_code=True,
            )
            if "jina-embeddings-v3" in model_name:
                model[0].default_task = config.DEFAULT.embedding_task
            # Publish the singleton only after the model loaded: if loading
            # raises, _instance stays None and the next call retries instead
            # of returning an object that lacks `embedding_model`.
            instance = super().__new__(cls)
            instance.embedding_model = model
            cls._instance = instance
            print(f"==== using device {device} ====")
        return cls._instance
|
|
|
|
|
def get_embedding_model(config):
    """Return the shared embedding model named by ``config.DEFAULT.embedding``.

    Runs ``check_embedding`` for that model name first, then obtains the
    ``EmbeddingModel`` singleton and returns its ``embedding_model``.
    """
    print("=== get embedding model ===")
    model_name = config.DEFAULT.embedding
    check_embedding(model_name)
    holder = EmbeddingModel(config)
    return holder.embedding_model
|
|
|
|
|
def generate_hash_id(input_string):
    """Derive a non-negative integer id from a string's SHA-256 digest.

    The string is lower-cased and UTF-8 encoded, hashed with SHA-256, and the
    first 8 digest bytes are read as a big-endian signed 64-bit integer whose
    absolute value is returned. Inputs differing only in case therefore map
    to the same id.

    Args:
        input_string: Text to hash, or ``None``.

    Returns:
        The integer id, or ``None`` when ``input_string`` is ``None``.
    """
    if input_string is None:
        return None
    normalized = input_string.lower().encode("utf-8")
    digest = hashlib.sha256(normalized).digest()
    # ">q" = big-endian signed 64-bit; unpack yields a 1-tuple.
    (signed_id,) = struct.unpack(">q", digest[:8])
    return abs(signed_id)
|
|
|
|
|
def extract_ref_id(text, references):
    """Turn the bracketed citations in ``text`` into reference hash ids.

    Citation groups such as ``[3]`` or ``[1, 2, 7]`` are collected and each
    cited number is counted across the whole text. Numbers cited exactly
    once are then pruned group by group:

    * if every number in a group is cited only once, the first number of the
      group is kept and the others are dropped;
    * otherwise the once-cited numbers of that group are dropped and the
      repeatedly cited ones are kept.

    The surviving numbers, in ascending order, are used to index
    ``references`` and each entry is passed to ``generate_hash_id``.

    Args:
        text: Text that may contain bracketed numeric citations.
        references: paper["references"]; indexed directly with the cited
            number. NOTE(review): no offset is applied, so this presumably
            expects the citation numbers to be valid keys/indices of
            ``references`` -- confirm against the caller.

    Returns:
        List of hash ids for the retained citations; an empty list when
        ``text`` contains no bracketed citations (this case previously
        raised ``UnboundLocalError``).
    """
    # One list of ints per bracket group, in order of appearance.
    groups = [
        [int(number) for number in re.findall(r"\d+", group)]
        for group in re.findall(r"\[\d+(?:,\s*\d+)*\]", text)
    ]
    if not groups:
        return []

    ref_counts = Counter(number for group in groups for number in group)

    dropped = set()
    for group in groups:
        cited_once = [number for number in group if ref_counts[number] == 1]
        if len(cited_once) == len(group):
            # Nothing in this group is cited elsewhere: keep its first entry
            # as the group's representative.
            cited_once = cited_once[1:]
        dropped.update(cited_once)

    kept = sorted(number for number in ref_counts if number not in dropped)
    return [generate_hash_id(references[idx]) for idx in kept]
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: print the id generated for a sample string.
    demo_id = generate_hash_id("example_string")
    print("INT64 Hash ID:", demo_id)
|
|