Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import requests | |
import re | |
import emoji | |
import nltk | |
import lxml | |
import os | |
from bs4 import BeautifulSoup | |
from markdown import markdown | |
from nltk.corpus import stopwords | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer, util | |
from retry import retry | |
# Make sure NLTK's English stop-word list is available locally.
nltk.download('stopwords')
# Hugging Face access token from the environment (None when unset). Only the
# commented-out Inference API client below referenced it.
hf_token = os.getenv('HF_TOKEN')
model_id = "BAAI/bge-large-en-v1.5"
# Load the embedding model locally with sentence-transformers. The previous
# `pipeline("feature-extraction", ...)` call used a name that was never
# imported, and that pipeline yields per-token hidden states rather than one
# pooled vector per text.
embedding_model = SentenceTransformer(model_id)


def feature_extraction_pipeline(texts):
    """Embed *texts* with ``embedding_model`` and return plain Python lists.

    A single string yields one flat list of floats; a list of strings yields a
    list of such lists. The old name is kept so existing callers still work.
    """
    return embedding_model.encode(texts).tolist()
# model_id = "BAAI/bge-large-en-v1.5" | |
# api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}" | |
# headers = {"Authorization": f"Bearer {hf_token}"} | |
# @retry(tries=3, delay=10) | |
# def query(texts): | |
# response = requests.post(api_url, headers=headers, json={"inputs": texts}) | |
# if response.status_code == 200: | |
# result = response.json() | |
# if isinstance(result, list): | |
# return result | |
# elif 'error' in result: | |
# raise RuntimeError("Error from Hugging Face API: " + result['error']) | |
# else: | |
# raise RuntimeError("Failed to get response from Hugging Face API, status code: " + str(response.status_code)) | |
# Load the precomputed blog embeddings. Each DataFrame column belongs to one
# blog (its column name is later used as the blog's `local` id), so the frame
# is transposed to get one embedding row per blog for semantic search.
faqs_embeddings_dataset = load_dataset('chenglu/hf-blogs-baai-embeddings')
df = faqs_embeddings_dataset["train"].to_pandas()
embeddings_array = df.T.to_numpy()
dataset_embeddings = torch.from_numpy(embeddings_array).float()
# Load the original blog records; `get_tags_for_local` looks tags up here.
_hf_blogs = load_dataset("chenglu/hf-blogs")
original_dataset = _hf_blogs['train']
# English stop words, held in a set for fast membership checks.
stop_words = set(stopwords.words('english'))
def remove_stopwords(text):
    """Drop English stop words from *text*, comparing case-insensitively.

    The text is split on whitespace and the surviving tokens are re-joined with
    single spaces, so the original spacing is not preserved.
    """
    kept_tokens = []
    for token in text.split():
        if token.lower() in stop_words:
            continue
        kept_tokens.append(token)
    return ' '.join(kept_tokens)
def clean_content(content):
    """Reduce Markdown/HTML text to plain, space-separated English words.

    Steps, in order: drop fenced and inline code; render Markdown to HTML and
    keep only the text (this also strips raw HTML in the input); remove URLs;
    remove emoji; keep only ASCII letters and whitespace; collapse whitespace
    runs to single spaces.

    Args:
        content: Raw Markdown and/or HTML text.

    Returns:
        The cleaned text, with no leading or trailing whitespace.
    """
    # Code snippets carry little topical signal; drop them before rendering.
    content = re.sub(r"(```.*?```|`.*?`)", "", content, flags=re.DOTALL)
    # Render Markdown first: once punctuation is stripped below, no Markdown
    # syntax is left to render. Link targets land in href attributes, which
    # get_text() discards.
    content = markdown(content)
    content = BeautifulSoup(content, "lxml").get_text()
    # Remove URLs before the letters-only filter strips the punctuation that
    # delimits them.
    content = re.sub(r"http\S+|www\S+", "", content)
    content = emoji.replace_emoji(content, replace="")
    content = re.sub(r"[^a-zA-Z\s]", "", content)
    content = re.sub(r"\s+", " ", content)
    return content.strip()
def get_tags_for_local(dataset, local_value):
    """Return the ``'tags'`` of the first record whose ``'local'`` equals *local_value*.

    Args:
        dataset: Iterable of mapping-like records with ``'local'`` and ``'tags'`` keys.
        local_value: The ``'local'`` identifier to look for.

    Returns:
        That record's ``'tags'`` value, or ``None`` when no record matches.
    """
    for record in dataset:
        if record['local'] != local_value:
            continue
        return record['tags']
    return None
def gradio_query_interface(input_text):
    """Recommend category tags for a piece of blog text.

    The text is cleaned with ``clean_content``, stripped of English stop words,
    embedded, and matched against the precomputed blog embeddings with
    ``util.semantic_search``.

    Args:
        input_text: Raw text from the Gradio textbox (``None`` is treated as empty).

    Returns:
        ``"Content Not related"`` when the cleaned text is empty or no hit scores
        at least 0.6; otherwise a string with the tags of the best-matching blog.
    """
    cleaned_text = clean_content(input_text or "")
    no_stopwords_text = remove_stopwords(cleaned_text)
    # Nothing meaningful left to embed (empty box, only code/URLs/emoji, ...).
    if not no_stopwords_text:
        return "Content Not related"

    # Embed the cleaned text; previously the cleaning was computed and then
    # discarded in favour of the raw input. NOTE(review): assumes the stored
    # blog embeddings were built from similarly cleaned text - confirm.
    new_embedding = feature_extraction_pipeline(no_stopwords_text)
    query_embeddings = torch.as_tensor(new_embedding, dtype=torch.float)

    # One query in, so only the first result list is relevant.
    hits = util.semantic_search(query_embeddings, dataset_embeddings, top_k=5)[0]
    best_hit = max(hits, key=lambda hit: hit['score'], default=None)
    if best_hit is None or best_hit['score'] < 0.6:
        return "Content Not related"

    # Corpus rows were built from the transposed frame, so a corpus id indexes
    # df's columns, whose names are the blogs' `local` ids.
    local = df.columns[best_hit['corpus_id']]
    recommended_tags = get_tags_for_local(original_dataset, local)
    return f"Recommended category tags: {recommended_tags}"
# Web UI: a single free-text input; the returned string is shown in a Label.
iface = gr.Interface(gradio_query_interface, "text", "label")
iface.launch()