import gradio as gr
import torch
import requests
import re
import emoji
import nltk
import lxml
import os
from bs4 import BeautifulSoup
from markdown import markdown
from nltk.corpus import stopwords
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from retry import retry
from transformers import pipeline

# Translation pipeline (note: loaded here but never used below)
pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")

# Make sure the NLTK stopword list has been downloaded
nltk.download('stopwords')

# Read the Hugging Face token from the environment
# (only needed by the commented-out Inference API path below)
hf_token = os.getenv('HF_TOKEN')


model_id = "BAAI/bge-large-en-v1.5"
feature_extraction_pipeline = pipeline("feature-extraction", model=model_id)

# Alternative: call the hosted Inference API instead of loading the model locally
# model_id = "BAAI/bge-large-en-v1.5"
# api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
# headers = {"Authorization": f"Bearer {hf_token}"}

@retry(tries=3, delay=10)
def query(texts):
    # Extract token-level features with the feature-extraction pipeline
    features = feature_extraction_pipeline(texts)

    # Mean-pool each sentence's token features into a single vector, then
    # stack the vectors into a 2-D tensor of shape (num_sentences, hidden_dim).
    # features is a list with one entry of token embeddings per sentence.
    embeddings = [torch.tensor(f).mean(dim=0) for f in features]
    embeddings = torch.stack(embeddings)

    return embeddings
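
# A minimal sanity check for query(), left commented out so startup is
# unchanged; the sample sentence is illustrative, and the 1024 dimension is
# an assumption based on bge-large-en-v1.5's hidden size:
# emb = query("how do I deploy a model?")
# print(emb.shape)  # expected: torch.Size([1, 1024])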

# Alternative query() that calls the hosted Inference API instead:
# def query(texts):
#     response = requests.post(api_url, headers=headers, json={"inputs": texts})
#     if response.status_code == 200:
#         result = response.json()
#         if isinstance(result, list):
#             return result
#         elif 'error' in result:
#             raise RuntimeError("Error from Hugging Face API: " + result['error'])
#     else:
#         raise RuntimeError("Failed to get response from Hugging Face API, status code: " + str(response.status_code))

# Load the dataset of precomputed blog-post embeddings
faqs_embeddings_dataset = load_dataset('chenglu/hf-blogs-baai-embeddings')
df = faqs_embeddings_dataset["train"].to_pandas()
# Each column holds one post's embedding, so transpose to get one row per post
embeddings_array = df.T.to_numpy()
dataset_embeddings = torch.from_numpy(embeddings_array).to(torch.float)
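
# Sketch of the resulting layout (assuming one embedding column per post):
# df has shape (hidden_dim, num_posts), with each column keyed by the post's
# 'local' id, so dataset_embeddings comes out as (num_posts, hidden_dim) —
# the corpus shape that util.semantic_search expects.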

# Load the original blog dataset (holds the tags for each post)
original_dataset = load_dataset("chenglu/hf-blogs")['train']

# Build the set of English stopwords
stop_words = set(stopwords.words('english'))

def remove_stopwords(text):
    return ' '.join([word for word in text.split() if word.lower() not in stop_words])
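
# Illustrative example: remove_stopwords("this is a test of the system")
# returns "test system", since NLTK's English stopword list contains
# "this", "is", "a", "of", and "the".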

def clean_content(content):
    # Strip fenced and inline code spans
    content = re.sub(r"(```.*?```|`.*?`)", "", content, flags=re.DOTALL)
    # Remove URLs while their punctuation is still intact; the non-letter
    # filter below would otherwise break the pattern before it could match
    content = re.sub(r"http\S+|www\S+|https\S+", '', content, flags=re.MULTILINE)
    # Drop raw HTML tags
    content = BeautifulSoup(content, "html.parser").get_text()
    # Remove emoji
    content = emoji.replace_emoji(content, replace='')
    # Render the remaining Markdown to HTML and keep only the text nodes
    content = markdown(content)
    content = ''.join(BeautifulSoup(content, 'lxml').find_all(string=True))
    # Keep only letters and whitespace, then collapse repeated whitespace
    content = re.sub(r"[^a-zA-Z\s]", "", content)
    content = re.sub(r'\s+', ' ', content)
    return content
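
# Illustrative example (the sample string is hypothetical):
# clean_content("Use `pip install` at https://hf.co! 🤗") yields roughly
# "Use at" — the code span, URL, emoji, and punctuation are all stripped
# before whitespace is collapsed.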

def get_tags_for_local(dataset, local_value):
    # Find the dataset entry whose 'local' id matches and return its tags
    entry = next((item for item in dataset if item['local'] == local_value), None)
    if entry:
        return entry['tags']
    else:
        return None

def gradio_query_interface(input_text):
    # Clean and normalize the input the same way the corpus was preprocessed
    cleaned_text = clean_content(input_text)
    no_stopwords_text = remove_stopwords(cleaned_text)
    new_embedding = query(no_stopwords_text)
    # new_embedding = feature_extraction_pipeline(input_text)
    # query() already returns a tensor, so just ensure it is float-typed
    query_embeddings = new_embedding.float()
    hits = util.semantic_search(query_embeddings, dataset_embeddings, top_k=5)
    # If no hit clears the 0.6 similarity threshold, treat the input as unrelated
    if all(hit['score'] < 0.6 for hit in hits[0]):
        return "Content not related"
    else:
        highest_score_result = max(hits[0], key=lambda x: x['score'])
        highest_score_corpus_id = highest_score_result['corpus_id']
        # The columns of df are named by the posts' 'local' ids
        local = df.columns[highest_score_corpus_id]
        recommended_tags = get_tags_for_local(original_dataset, local)
        return f"Recommended category tags: {recommended_tags}"

iface = gr.Interface(
    fn=gradio_query_interface, 
    inputs="text", 
    outputs="label"
)

iface.launch()
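
# Note: when run locally, launch() serves the app at http://127.0.0.1:7860 by
# default; on Hugging Face Spaces the hosting runtime picks it up automatically.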