# hf-blog-tags / app.py
# Last change: "Update app.py" by chenglu (commit ee48dec), file size 3.55 kB.
import gradio as gr
import torch
import requests
import re
import emoji
import nltk
import lxml
import os
from bs4 import BeautifulSoup
from markdown import markdown
from nltk.corpus import stopwords
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from retry import retry
from transformers import pipeline
# English -> Spanish translation pipeline.
# NOTE(review): `pipe` appears to be unused by the rest of this app, yet building
# it downloads and loads a second model at import time -- confirm it is still needed.
pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")
# Make sure the NLTK stop-word corpus is available locally (read below to build
# the English stop-word set).
nltk.download('stopwords')
# Hugging Face access token from the environment (None if HF_TOKEN is unset).
# NOTE(review): appears to be referenced only by the commented-out Inference API
# client below.
hf_token = os.getenv('HF_TOKEN')
# Embedding model, run locally through a transformers feature-extraction pipeline.
model_id = "BAAI/bge-large-en-v1.5"
feature_extraction_pipeline = pipeline("feature-extraction", model=model_id)
# Previous implementation, kept for reference: the same model queried through the
# hosted Inference API instead of locally (up to 3 attempts, 10 s apart).
# model_id = "BAAI/bge-large-en-v1.5"
# api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
# headers = {"Authorization": f"Bearer {hf_token}"}
# @retry(tries=3, delay=10)
# def query(texts):
#     response = requests.post(api_url, headers=headers, json={"inputs": texts})
#     if response.status_code == 200:
#         result = response.json()
#         if isinstance(result, list):
#             return result
#         elif 'error' in result:
#             raise RuntimeError("Error from Hugging Face API: " + result['error'])
#     else:
#         raise RuntimeError("Failed to get response from Hugging Face API, status code: " + str(response.status_code))
# Load the precomputed blog-post embeddings. Each column of `df` is named by a
# post's `local` identifier (search hits are mapped back through df.columns);
# presumably each column holds that post's embedding vector -- confirm.
faqs_embeddings_dataset = load_dataset('chenglu/hf-blogs-baai-embeddings')
df = faqs_embeddings_dataset["train"].to_pandas()
# Transpose so each ROW is one post's embedding: row i corresponds to df.columns[i].
embeddings_array = df.T.to_numpy()
# Corpus matrix for semantic search, cast to float32.
dataset_embeddings = torch.from_numpy(embeddings_array).to(torch.float)
# Load the original blog dataset; rows are looked up by their 'local' field to
# read the post's 'tags'.
original_dataset = load_dataset("chenglu/hf-blogs")['train']
# English stop-word set used to filter query text before embedding.
stop_words = set(stopwords.words('english'))
def remove_stopwords(text, stop_word_set=None):
    """Drop stop words from *text* and re-join the remaining words with single spaces.

    Args:
        text: Input string; it is split on any run of whitespace.
        stop_word_set: Words to drop. Each word of *text* is lower-cased before
            the membership test, so entries should be lower-case. Defaults to
            the module-level NLTK English ``stop_words`` set.

    Returns:
        The surviving words, in their original order and casing, joined by ' '.
    """
    if stop_word_set is None:
        # Resolved at call time so the module-level set is only needed by
        # callers that rely on the default.
        stop_word_set = stop_words
    return ' '.join(word for word in text.split() if word.lower() not in stop_word_set)
def clean_content(content):
    """Reduce post content to plain alphabetic words separated by single spaces.

    The steps run in this fixed order:
      1. remove fenced (```...```) and inline (`...`) code spans;
      2. strip HTML tags;
      3. remove emoji;
      4. keep only ASCII letters and whitespace;
      5. remove URL residue (tokens starting with "http"/"www");
      6. render as Markdown and extract the text again;
      7. collapse every whitespace run into one space.
    """
    without_code = re.sub(r"(```.*?```|`.*?`)", "", content, flags=re.DOTALL)
    plain_text = BeautifulSoup(without_code, "html.parser").get_text()
    no_emoji = emoji.replace_emoji(plain_text, replace='')
    letters_only = re.sub(r"[^a-zA-Z\s]", "", no_emoji)
    # NOTE(review): punctuation is already gone here, so this matches the
    # letter-only remains of a URL (e.g. "httpsexamplecom") -- and any ordinary
    # word that starts with "http" or "www".
    no_urls = re.sub(r"http\S+|www\S+|https\S+", '', letters_only, flags=re.MULTILINE)
    # NOTE(review): with only letters and whitespace left, this Markdown round
    # trip appears to do little beyond normalizing whitespace.
    html_doc = BeautifulSoup(markdown(no_urls), 'lxml')
    text_nodes = html_doc.findAll(text=True)
    return re.sub(r'\s+', ' ', ''.join(text_nodes))
def get_tags_for_local(dataset, local_value):
    """Return the 'tags' of the first record whose 'local' equals *local_value*.

    Args:
        dataset: Iterable of mapping-like records with 'local' and 'tags' keys.
        local_value: The 'local' identifier to search for.

    Returns:
        The matching record's 'tags', or None when no record matches.
    """
    for record in dataset:
        if record['local'] != local_value:
            continue
        if record:
            return record['tags']
        break
    return None
def _embed_query(text):
    """Embed *text* with the local feature-extraction pipeline as a (1, hidden) tensor."""
    # For a single string the pipeline returns per-token hidden states nested
    # as [batch][token][hidden]. util.semantic_search needs one vector per
    # query (a 3-D tensor makes its cosine-similarity matrix multiply fail),
    # so pool by keeping only the first ([CLS]) token -- the pooling the BGE
    # model card prescribes for sentence embeddings.
    # NOTE(review): assumes the stored corpus embeddings used the same pooling.
    # truncation=True keeps long posts within the model's maximum input length
    # instead of raising.
    token_states = torch.FloatTensor(feature_extraction_pipeline(text, truncation=True))
    if token_states.dim() == 3:
        token_states = token_states[0]  # drop the batch axis -> (tokens, hidden)
    return token_states[:1]  # [CLS] row, kept 2-D -> (1, hidden)


def gradio_query_interface(input_text):
    """Recommend category tags for a blog post via nearest-neighbour search.

    The input is cleaned with ``clean_content``, stripped of English stop
    words, embedded, and compared against ``dataset_embeddings`` with
    ``util.semantic_search`` (top 5). If no hit scores at least 0.6 the text is
    reported as unrelated; otherwise the tags of the best-scoring post are
    returned.

    Args:
        input_text: Raw post content (Markdown/HTML) from the UI text box.

    Returns:
        "Content Not related", or "Recommended category tags: <tags>".
    """
    cleaned_text = clean_content(input_text or "")
    no_stopwords_text = remove_stopwords(cleaned_text)
    if not no_stopwords_text:
        # Nothing meaningful left to embed (empty input, or only code/symbols).
        return "Content Not related"
    # Embed the preprocessed text, not the raw input, so the cleaning above is
    # actually used (the former Inference-API path also sent this text).
    # NOTE(review): assumes the corpus embeddings were built from text
    # preprocessed the same way -- confirm against the dataset's build script.
    query_embeddings = _embed_query(no_stopwords_text)
    hits = util.semantic_search(query_embeddings, dataset_embeddings, top_k=5)[0]
    if not hits:
        return "Content Not related"
    best_hit = max(hits, key=lambda hit: hit['score'])
    if best_hit['score'] < 0.6:
        return "Content Not related"
    # Corpus row i was built from df.columns[i], which holds the post's 'local' id.
    local = df.columns[best_hit['corpus_id']]
    recommended_tags = get_tags_for_local(original_dataset, local)
    return f"Recommended category tags: {recommended_tags}"
# Web UI: one text box in, the recommendation rendered by a Label component out.
# Positional order follows gr.Interface(fn, inputs, outputs).
iface = gr.Interface(gradio_query_interface, "text", "label")
# Start the Gradio server.
iface.launch()