File size: 5,855 Bytes
e17c9f2 a6a5155 e17c9f2 4117eaa a6a5155 e17c9f2 b6336ac a6a5155 02069d7 b6336ac a6a5155 02069d7 c8709b2 02069d7 c8709b2 c9fbbef c8709b2 d6249c2 d645672 c8709b2 02069d7 b6336ac a6a5155 b6336ac e17c9f2 a6a5155 e17c9f2 b6336ac 02069d7 b6336ac e17c9f2 a6a5155 02069d7 a6a5155 de0c71d a6a5155 02069d7 de0c71d 02069d7 a6a5155 c8709b2 de0c71d a6a5155 02069d7 a6a5155 02069d7 a6a5155 e17c9f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import re
import os
import hashlib
import torch
import struct
from collections import Counter
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from .header import get_dir
# Module-level one-shot flags; each starts False and is flipped to True by the
# matching check_* function in this module once that check has completed.
# Set by check_env() after the required environment variables validated.
ENV_CHECKED = False
# Set by check_embedding() after the model files were verified/downloaded.
EMBEDDING_CHECKED = False
def check_embedding(repo_id):
    """Ensure the files of the embedding model ``repo_id`` exist locally.

    Files are looked up under ``./assets/model/<repo_id>``; every file that is
    missing is fetched with ``hf_hub_download``. The check runs at most once
    per process: after the first successful run ``EMBEDDING_CHECKED`` is set
    and later calls return immediately, whatever ``repo_id`` they pass.

    Args:
        repo_id: Hugging Face repository id of a supported embedding model.

    Raises:
        ValueError: If ``repo_id`` is not one of the supported repositories.
            (Previously this surfaced as an ``UnboundLocalError`` because no
            file list was selected.)
    """
    global EMBEDDING_CHECKED
    print("=== check embedding model ===")
    # NOTE(review): the flag is global rather than per repo, so a second,
    # different repo_id in the same process is never checked -- confirm that
    # only one embedding model is used per run.
    if EMBEDDING_CHECKED:
        return

    local_dir = f"./assets/model/{repo_id}"
    # Pick the file list that matches the repository layout.
    if repo_id in (
        "sentence-transformers/all-MiniLM-L6-v2",
        "BAAI/bge-small-en-v1.5",
        "BAAI/llm-embedder",
    ):
        files_to_download = [
            "config.json",
            "pytorch_model.bin",
            "tokenizer_config.json",
            "vocab.txt",
        ]
    elif repo_id in ("jinaai/jina-embeddings-v3",):
        files_to_download = [
            "model.safetensors",
            "modules.json",
            "tokenizer.json",
            "config_sentence_transformers.json",
            "custom_st.py",
            "special_tokens_map.json",
            "tokenizer_config.json",
            "1_Pooling/config.json",
            "config.json",
        ]
    elif repo_id in ("Alibaba-NLP/gte-base-en-v1.5",):
        files_to_download = [
            "config.json",
            "model.safetensors",
            "modules.json",
            "tokenizer.json",
            "sentence_bert_config.json",
            "tokenizer_config.json",
            "vocab.txt",
        ]
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"unsupported embedding repo_id: {repo_id!r}")

    # Download only the files that are not already present in local_dir.
    for file_name in files_to_download:
        if os.path.exists(os.path.join(local_dir, file_name)):
            continue
        print(
            f"file: {file_name} not exist in {local_dir}, try to download from huggingface ..."
        )
        hf_hub_download(
            repo_id=repo_id,
            filename=file_name,
            local_dir=local_dir,
        )
    EMBEDDING_CHECKED = True
def check_env():
    """Validate the required environment variables once per process.

    The variables NEO4J_URL, NEO4J_USERNAME, NEO4J_PASSWD, MODEL_NAME,
    MODEL_TYPE and BASE_URL must be set and non-empty. MODEL_API_KEY is
    additionally required unless MODEL_TYPE equals "Local". On success
    ``ENV_CHECKED`` is set so that later calls return immediately.

    Raises:
        ValueError: For the first variable found missing or empty.
    """
    global ENV_CHECKED
    if ENV_CHECKED:
        return

    def _require(name):
        # A variable counts as unset when it is absent or an empty string.
        if os.environ.get(name, "") == "":
            raise ValueError(f"{name} is not set...")

    for required_name in (
        "NEO4J_URL",
        "NEO4J_USERNAME",
        "NEO4J_PASSWD",
        "MODEL_NAME",
        "MODEL_TYPE",
        "BASE_URL",
    ):
        _require(required_name)

    # Only non-local model types need an API key.
    if os.environ["MODEL_TYPE"] != "Local":
        _require("MODEL_API_KEY")

    ENV_CHECKED = True
class EmbeddingModel:
    """Singleton holder for the SentenceTransformer embedding model.

    The first successful ``EmbeddingModel(config)`` call loads the model from
    ``./assets/model/<config.DEFAULT.embedding>`` (resolved through
    ``get_dir``) onto CUDA when available, otherwise the CPU, and stores it
    in the ``embedding_model`` attribute. Every later call returns that same
    instance and ignores its ``config`` argument.
    """

    # Cached singleton; stays None until a model has been loaded successfully.
    _instance = None

    def __new__(cls, config):
        """Return the shared instance, loading the model on first use.

        Args:
            config: Object exposing ``DEFAULT.embedding`` (model name) and,
                for jina-embeddings-v3 models, ``DEFAULT.embedding_task``.

        Returns:
            The shared ``EmbeddingModel`` instance.
        """
        if cls._instance is None:
            local_dir = f"./assets/model/{config.DEFAULT.embedding}"
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Build on a local name and publish only after loading succeeded;
            # otherwise a failed load would leave a cached instance without
            # an ``embedding_model`` attribute for all later callers.
            instance = super().__new__(cls)
            # NOTE(review): trust_remote_code=True lets the loader run Python
            # code shipped with the model files -- only use trusted models.
            instance.embedding_model = SentenceTransformer(
                model_name_or_path=get_dir(local_dir),
                device=device,
                trust_remote_code=True,
            )
            if "jina-embeddings-v3" in config.DEFAULT.embedding:
                # jina-embeddings-v3 takes its task from the first module.
                instance.embedding_model[0].default_task = config.DEFAULT.embedding_task
            print(f"==== using device {device} ====")
            cls._instance = instance
        return cls._instance
def get_embedding_model(config):
    """Return the shared embedding model for ``config.DEFAULT.embedding``.

    Runs ``check_embedding`` for the configured model name first, then
    returns the ``embedding_model`` held by the ``EmbeddingModel`` singleton.
    """
    print("=== get embedding model ===")
    embedding_name = config.DEFAULT.embedding
    check_embedding(embedding_name)
    holder = EmbeddingModel(config)
    return holder.embedding_model
def generate_hash_id(input_string):
    """Derive a non-negative integer id from a string, ignoring case.

    The lower-cased string is UTF-8 encoded and hashed with SHA-256; the
    first 8 digest bytes are read as a big-endian signed 64-bit integer and
    its absolute value is returned.

    Args:
        input_string: Text to hash, or None.

    Returns:
        The integer id, or None when ``input_string`` is None.
    """
    if input_string is None:
        return None
    normalized = input_string.lower().encode("utf-8")
    digest = hashlib.sha256(normalized).digest()
    signed_value = int.from_bytes(digest[:8], byteorder="big", signed=True)
    return abs(signed_value)
def extract_ref_id(text, references):
    """Turn bracketed numeric citations in ``text`` into hashed reference ids.

    Citation groups such as ``[5]`` or ``[15, 16]`` are collected and every
    cited number is counted across the whole text. Then, for each group, the
    numbers that are cited only once overall are dropped -- except that when
    *every* number of the group is cited only once, the first one is kept.
    The surviving numbers, in ascending order, are used to index
    ``references`` and each looked-up entry is passed to ``generate_hash_id``.

    Args:
        text: Text that may contain citations like ``[2, 3, 8]``.
        references: ``paper["references"]``; indexed directly with the
            integer citation number.

    Returns:
        List of hash ids for the retained citations; an empty list when
        ``text`` contains no bracketed citation (previously this case raised
        ``UnboundLocalError``).
    """
    # NOTE(review): citation numbers index ``references`` as-is (no -1
    # offset) and out-of-range numbers are not guarded -- confirm against the
    # caller's data that this matches how references are numbered.

    # Matches "[number]" or "[number, number, ...]".
    pattern = r"\[\d+(?:,\s*\d+)*\]"
    # One list of ints per citation group, e.g. [[15, 16], [5], [2, 3, 8]].
    groups = [
        [int(num) for num in re.findall(r"\d+", ref)]
        for ref in re.findall(pattern, text)
    ]
    if not groups:
        return []

    # Count how often each number is cited, ordered by citation number.
    ref_counts = Counter(num for group in groups for num in group)
    ref_counts = dict(sorted(ref_counts.items()))

    for group in groups:
        # Numbers of this group that are cited exactly once in the text.
        cited_once = [num for num in group if ref_counts[num] == 1]
        if len(cited_once) == len(group):
            # Nothing in the group is cited elsewhere: keep its first entry.
            cited_once = cited_once[1:]
        for num in cited_once:
            del ref_counts[num]

    return [generate_hash_id(references[idx]) for idx in ref_counts]
if __name__ == "__main__":
    # Example usage: hash a sample string and print the resulting id.
    sample_text = "example_string"
    print("INT64 Hash ID:", generate_hash_id(sample_text))
|