File size: 4,235 Bytes
e17c9f2
 
 
a6a5155
e17c9f2
 
4117eaa
a6a5155
 
e17c9f2
b6336ac
 
 
a6a5155
4117eaa
b6336ac
 
 
 
a6a5155
b6336ac
 
 
 
 
 
 
 
a6a5155
 
 
 
 
 
 
 
 
b6336ac
e17c9f2
a6a5155
e17c9f2
b6336ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e17c9f2
 
a6a5155
 
 
 
 
 
de0c71d
a6a5155
 
de0c71d
a6a5155
de0c71d
a6a5155
 
 
 
 
 
e17c9f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import re
import os
import hashlib
import torch
import struct
from collections import Counter
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from .header import get_dir

# Flipped to True by check_env() once every required environment variable
# has been found, so later calls skip the check.
ENV_CHECKED = False
# Flipped to True by check_embedding() once its file check/download pass
# has completed, so later calls skip the check.
EMBEDDING_CHECKED = False


def check_embedding():
    """Make sure the embedding model files exist locally, downloading any that are missing.

    The check runs at most once per process: after a complete pass the
    module-level ``EMBEDDING_CHECKED`` flag is set and later calls return
    immediately.
    """
    global EMBEDDING_CHECKED
    if EMBEDDING_CHECKED:
        return

    # Hugging Face repository to pull from; files are stored under a
    # directory named after the repository id.
    repo_id = "sentence-transformers/all-MiniLM-L6-v2"  # "BAAI/bge-small-en-v1.5"
    local_dir = f"./assets/model/{repo_id}"
    required_files = (
        "config.json",
        "pytorch_model.bin",
        "tokenizer_config.json",
        "vocab.txt",
    )

    for file_name in required_files:
        if os.path.exists(os.path.join(local_dir, file_name)):
            continue
        print(
            f"file: {file_name} not exist in {local_dir}, try to download from huggingface ..."
        )
        hf_hub_download(
            repo_id=repo_id,
            filename=file_name,
            local_dir=local_dir,
        )

    EMBEDDING_CHECKED = True


def check_env():
    """Verify that every required environment variable is set and non-empty.

    The check runs at most once per process: after it succeeds the
    module-level ``ENV_CHECKED`` flag is set and later calls return
    immediately.

    Raises:
        ValueError: Naming the first required variable (in list order)
            that is missing or set to an empty string.
    """
    global ENV_CHECKED
    if ENV_CHECKED:
        return

    required_names = (
        "NEO4J_URL",
        "NEO4J_USERNAME",
        "NEO4J_PASSWD",
        "MODEL_NAME",
        "MODEL_TYPE",
        "MODEL_API_KEY",
        "BASE_URL",
    )
    # os.environ.get() yields None for a missing variable and "" for an
    # empty one; both are falsy, so one truthiness test covers both cases.
    first_missing = next(
        (name for name in required_names if not os.environ.get(name)),
        None,
    )
    if first_missing is not None:
        raise ValueError(f"{first_missing} is not set...")

    ENV_CHECKED = True


class EmbeddingModel:
    """Process-wide singleton holding a loaded ``SentenceTransformer``.

    The model is loaded on the first successful construction only. Every
    later ``EmbeddingModel(config)`` call returns that same object, and its
    ``config`` argument is ignored.

    Attributes:
        embedding_model: The ``SentenceTransformer`` created on first use.
    """

    # The shared instance; remains None until a model has loaded successfully.
    _instance = None

    def __new__(cls, config):
        """Return the shared instance, loading the model on first use.

        Args:
            config: Object exposing ``config.DEFAULT.embedding``; that value
                is passed through ``get_dir`` to obtain the model path.

        Returns:
            The single ``EmbeddingModel`` instance.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            instance.embedding_model = SentenceTransformer(
                model_name_or_path=get_dir(config.DEFAULT.embedding),
                device=device,
            )
            # Publish the instance only after the model has loaded. If the
            # load raises, nothing is cached, so the next call retries
            # instead of returning an object without ``embedding_model``.
            cls._instance = instance
            print(f"==== using device {device} ====")
        return cls._instance

def get_embedding_model(config):
    """Return the embedding model held by the ``EmbeddingModel`` instance for *config*."""
    holder = EmbeddingModel(config)
    return holder.embedding_model


def generate_hash_id(input_string):
    """Derive a non-negative integer id from a string, ignoring letter case.

    The string is lower-cased, UTF-8 encoded and hashed with SHA-256. The
    first 8 bytes of the digest are read as a big-endian signed 64-bit
    integer, and the absolute value of that integer is returned.

    Args:
        input_string: Text to hash, or None.

    Returns:
        The integer id, or None when ``input_string`` is None.
    """
    if input_string is None:
        return None
    normalized = input_string.lower().encode("utf-8")
    digest = hashlib.sha256(normalized).digest()
    signed_value = int.from_bytes(digest[:8], byteorder="big", signed=True)
    return abs(signed_value)


def extract_ref_id(text, references):
    """Resolve bracketed numeric citations in *text* to hashed reference ids.

    Citation groups such as ``[5]`` or ``[2, 3, 8]`` are collected and each
    number's total occurrence count across the text is computed. Then, for
    every group, the numbers cited only once overall are discarded; if that
    would discard the whole group, its first number is kept. The surviving
    numbers, in ascending order, index into *references* and each entry is
    passed to ``generate_hash_id``.

    Args:
        text: Text to scan for citation markers.
        references: ``paper["references"]``; indexed directly with the
            citation number (no offset is applied).

    Returns:
        A list of hash ids, one per surviving citation number. Empty when
        *text* contains no citation markers.
    """
    # Matches bracketed groups of comma-separated integers, e.g. "[15, 16]".
    pattern = r"\[\d+(?:,\s*\d+)*\]"
    ref_list = re.findall(pattern, text)
    # e.g. ref_list == ['[15, 16]', '[5]', '[2, 3, 8]']
    if not ref_list:
        # No citations found: nothing to resolve. Returning here avoids
        # reading ``ref_counts`` before it has been assigned.
        return []

    # Parse each bracket group into its list of integers once.
    groups = [[int(num) for num in re.findall(r"\d+", ref)] for ref in ref_list]

    # Occurrence count per citation number, ordered by number.
    ref_counts = Counter(num for group in groups for num in group)
    ref_counts = dict(sorted(ref_counts.items()))

    # Within each group, drop numbers that are cited only once overall.
    for group in groups:
        cited_once = [num for num in group if ref_counts[num] == 1]
        if len(cited_once) == len(group):
            # Every number in the group is cited once: keep the first one
            # so the group still contributes a reference.
            cited_once = cited_once[1:]
        for num in cited_once:
            del ref_counts[num]

    return [generate_hash_id(references[idx]) for idx in ref_counts]


if __name__ == "__main__":
    # Example usage: hash a sample string and print the resulting id.
    sample_text = "example_string"
    print("INT64 Hash ID:", generate_hash_id(sample_text))