File size: 4,235 Bytes
e17c9f2 a6a5155 e17c9f2 4117eaa a6a5155 e17c9f2 b6336ac a6a5155 4117eaa b6336ac a6a5155 b6336ac a6a5155 b6336ac e17c9f2 a6a5155 e17c9f2 b6336ac e17c9f2 a6a5155 de0c71d a6a5155 de0c71d a6a5155 de0c71d a6a5155 e17c9f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import re
import os
import hashlib
import torch
import struct
from collections import Counter
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from .header import get_dir
# Process-wide one-shot flags: check_env()/check_embedding() below flip these
# so each validation runs at most once per interpreter session.
ENV_CHECKED = False
EMBEDDING_CHECKED = False
def check_embedding():
    """Ensure the embedding model files exist locally, downloading any missing ones.

    Runs at most once per process (guarded by the module flag
    EMBEDDING_CHECKED); on the first call, any file absent from the local
    model directory is fetched from the Hugging Face Hub.
    """
    global EMBEDDING_CHECKED
    if EMBEDDING_CHECKED:
        return
    # Repository holding the sentence-embedding model weights/tokenizer.
    repo_id = "sentence-transformers/all-MiniLM-L6-v2"  # "BAAI/bge-small-en-v1.5"
    local_dir = f"./assets/model/{repo_id}"
    required_files = (
        "config.json",
        "pytorch_model.bin",
        "tokenizer_config.json",
        "vocab.txt",
    )
    for required in required_files:
        # Only download files that are not already cached locally.
        if os.path.exists(os.path.join(local_dir, required)):
            continue
        print(
            f"file: {required} not exist in {local_dir}, try to download from huggingface ..."
        )
        hf_hub_download(
            repo_id=repo_id,
            filename=required,
            local_dir=local_dir,
        )
    # Only reached when every file is present (a failed download raises).
    EMBEDDING_CHECKED = True
def check_env():
    """Validate that every required environment variable is set and non-empty.

    Runs at most once per process (guarded by the module flag ENV_CHECKED).

    Raises:
        ValueError: if any required variable is missing or empty.
    """
    global ENV_CHECKED
    if ENV_CHECKED:
        return
    required = (
        "NEO4J_URL",
        "NEO4J_USERNAME",
        "NEO4J_PASSWD",
        "MODEL_NAME",
        "MODEL_TYPE",
        "MODEL_API_KEY",
        "BASE_URL",
    )
    for env_name in required:
        # Treat an unset variable and an empty string the same way.
        if os.environ.get(env_name, "") == "":
            raise ValueError(f"{env_name} is not set...")
    ENV_CHECKED = True
class EmbeddingModel:
    """Process-wide singleton holding one loaded SentenceTransformer.

    The first construction loads the model (on CUDA when available,
    otherwise CPU); every later call returns the same instance, and the
    `config` argument of later calls is ignored.
    """

    _instance = None

    def __new__(cls, config):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Prefer GPU when one is visible to torch.
            run_device = "cuda" if torch.cuda.is_available() else "cpu"
            cls._instance.embedding_model = SentenceTransformer(
                model_name_or_path=get_dir(config.DEFAULT.embedding),
                device=run_device,
            )
            print(f"==== using device {run_device} ====")
        return cls._instance
def get_embedding_model(config):
    """Return the shared SentenceTransformer owned by the EmbeddingModel singleton."""
    holder = EmbeddingModel(config)
    return holder.embedding_model
def generate_hash_id(input_string):
    """Map a string to a stable, case-insensitive, non-negative 64-bit int ID.

    The string is lower-cased, hashed with SHA-256, and the first 8 bytes
    of the digest are read as a big-endian signed 64-bit integer whose
    absolute value is returned.

    Args:
        input_string: String to hash, or None.

    Returns:
        A non-negative int (effectively 63 bits of entropy, since abs()
        folds +x and -x onto the same ID), or None when input is None.
    """
    if input_string is None:
        return None
    # Fixed: previously named `sha1_hash` although the algorithm is SHA-256,
    # and the digest was round-tripped through hexdigest()/fromhex for no
    # reason — .digest() yields the same bytes directly.
    digest = hashlib.sha256(input_string.lower().encode("utf-8")).digest()
    # ">q" = big-endian signed 64-bit integer.
    int64_hash = struct.unpack(">q", digest[:8])[0]
    return abs(int64_hash)
def extract_ref_id(text, references):
    """
    Extract bracketed citation markers like ``[15, 16]`` from *text* and
    return the hash IDs of the cited entries in *references*.

    references: paper["references"]
        NOTE(review): cited numbers are used directly as indices into
        *references* — confirm they are 0-based and always in range.

    Returns:
        List of IDs from generate_hash_id(), one per surviving reference
        number, in ascending numeric order.
    """
    # Regex matching the "[n, m, ...]" citation format, e.g. [2, 3, 8].
    pattern = r"\[\d+(?:,\s*\d+)*\]"
    # Collect every bracketed citation group found in the text.
    ref_list = re.findall(pattern, text)
    # e.g. ref_list == ['[15, 16]', '[5]', '[2, 3, 8]']
    combined_ref_list = []
    if len(ref_list) > 0:
        # "Pattern 0": bracketed numeric citations were found.
        for ref in ref_list:
            # Strip the brackets and pull out the individual numbers.
            numbers = re.findall(r"\d+", ref)
            # Convert to int and accumulate across all groups.
            combined_ref_list.extend(map(int, numbers))
    # Count occurrences of each reference number, then sort by number.
    ref_counts = Counter(combined_ref_list)
    ref_counts = dict(sorted(ref_counts.items()))
    # For multi-number groups, keep only the most-cited references:
    # prune numbers cited exactly once overall from each group.
    for ref in ref_list:
        # Strip the brackets and pull out the individual numbers.
        numbers = re.findall(r"\d+", ref)
        # Numbers in this group that appear only once in the whole text.
        temp_list = []
        for num in numbers:
            num = int(num)
            if ref_counts[num] == 1:
                temp_list.append(num)
        # If every number in the group is single-cited, spare the first
        # so the group is never pruned to nothing.
        if len(temp_list) == len(numbers):
            temp_list = temp_list[1:]
        for num in temp_list:
            del ref_counts[num]
    hash_id_list = []
    for idx in ref_counts.keys():
        hash_id_list.append(generate_hash_id(references[idx]))
    return hash_id_list
if __name__ == "__main__":
    # Smoke test: hash a sample string and print the resulting ID.
    sample = "example_string"
    print("INT64 Hash ID:", generate_hash_id(sample))
|