Spaces:
Running
Running
import streamlit as st | |
import requests | |
from bs4 import BeautifulSoup | |
import pandas as pd | |
import torch | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer, util | |
import concurrent.futures | |
import time | |
import sys | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.metrics.pairwise import cosine_similarity | |
from transformers import AutoTokenizer, AutoModel | |
import numpy as np | |
from scipy import stats | |
from PyDictionary import PyDictionary | |
import matplotlib.pyplot as plt | |
from scipy import stats | |
import litellm | |
import re | |
import sentencepiece | |
import random | |
from global_vars import t, translations | |
from app import Plugin | |
from embeddings_ft import finetune as finetune_embeddings | |
from bart_ft import finetune as finetune_bart | |
from webrankings_helper import * | |
from plugins.scansite import ScansitePlugin | |
#from data import reference_data_valid, reference_data_rejected | |
#reference_data = reference_data_valid + reference_data_rejected | |
# Plugin-specific UI strings (English).
translations["en"].update({
    # Fixed typo ("Comparative os sorter"); wording aligned with the French title.
    "webrankings_title": "Comparative sorter analyzer",
    "clear_memory": "Clear Memory",
    "enter_topic": "Enter the topic you're interested in (e.g. longevity):",
    "use_keyword_expansion": "Use keyword expansion",
    "test_content": "Also test link content in addition to titles",
    "select_llm_models": "Select LLM models to use",
    "select_zero_shot_models": "Select zero-shot models to use",
    "select_embedding_models": "Select embedding models to use",
    "analyze_button": "Analyze",
    "loading_models": "Loading models and analyzing links...",
    "expanded_keywords": "Expanded keywords:",
    "analysis_completed": "Analysis completed in {:.2f} seconds",
    "evaluation_results": "Evaluation results with optimal thresholds:",
    "summary_table": "Summary table of scores",
    "optimal_thresholds": "Optimal thresholds:",
    "spearman_comparison": "Comparison of Spearman correlations",
    "methods": "Methods",
    "spearman_correlation": "Spearman correlation coefficient",
    "results_for": "Results for {}",
    "device_info": "Device used for inference: {}",
    "finetune_bart_title": "BART Fine-tuning Interface",
    "finetune_embeddings_title": "Embeddings Fine-tuning Interface",
})
# Plugin-specific UI strings (French). Every key is a valid identifier, so the
# entries are passed as keyword arguments to dict.update().
translations["fr"].update(
    webrankings_title="Analyseur de classeurs",
    clear_memory="Vider la mémoire",
    enter_topic="Entrez le sujet qui vous intéresse (ex: longévité):",
    use_keyword_expansion="Utiliser l'expansion des mots-clés",
    test_content="Tester aussi le contenu des liens en plus des titres",
    select_llm_models="Sélectionnez les modèles LLM à utiliser",
    select_zero_shot_models="Sélectionnez les modèles zero-shot à utiliser",
    select_embedding_models="Sélectionnez les modèles d'embedding à utiliser",
    analyze_button="Analyser",
    loading_models="Chargement des modèles et analyse des liens...",
    expanded_keywords="Mots-clés étendus :",
    analysis_completed="Analyse terminée en {:.2f} secondes",
    evaluation_results="Résultats d'évaluation avec les seuils optimaux :",
    summary_table="Tableau récapitulatif des scores",
    optimal_thresholds="Seuils optimaux :",
    spearman_comparison="Comparaison des corrélations de Spearman",
    methods="Méthodes",
    spearman_correlation="Coefficient de corrélation de Spearman",
    results_for="Résultats pour {}",
    device_info="Dispositif utilisé pour l'inférence : {}",
    finetune_bart_title="Interface de Fine-tuning BART",
    finetune_embeddings_title="Interface de Fine-tuning des Embeddings",
)
# LLM models offered for keyword expansion. Empty by default; identifiers that
# were tried before include "ollama/llama3", "ollama/llama3.1", "ollama/qwen2",
# "ollama/phi3:medium-128k" and "ollama/openhermes".
llm_models = []

# Zero-shot classifiers as (display name, model id or local path) pairs.
# "./bart-large-ft" is the output directory of the BART fine-tuning tab.
# ("cross-encoder/nli-deberta-v3-base" was another candidate, currently disabled.)
zero_shot_models = [
    ("facebook/bart-large-mnli", "facebook/bart-large-mnli"),
    ("bart-large-ft", "./bart-large-ft"),
]

# Sentence-embedding models as (display name, model id or local path) pairs.
# "./embeddings-ft" is the output directory of the embeddings fine-tuning tab.
embedding_models = [
    ("paraphrase-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2"),
    ("all-MiniLM-L6-v2", "all-MiniLM-L6-v2"),
    ("nomic-embed-text-v1", "nomic-ai/nomic-embed-text-v1"),
    ("embeddings-ft", "./embeddings-ft"),
]
class WebrankingsPlugin(Plugin):
    """Streamlit plugin that compares link-ranking methods on labelled reference data.

    Three tabs are rendered:
      * an analysis tab that scores the reference links (obtained from the
        ``scansite`` plugin) with the selected zero-shot, embedding and LLM
        methods, then evaluates each method against the valid/rejected labels;
      * a tab that launches fine-tuning of the BART zero-shot model;
      * a tab that launches fine-tuning of the sentence-embedding model.
    """

    def __init__(self, name, plugin_manager):
        super().__init__(name, plugin_manager)
        # Provides the labelled (valid / rejected) reference links.
        self.scansite_plugin = ScansitePlugin('scansite', plugin_manager)

    def get_tabs(self):
        """Return the tab descriptors this plugin registers with the app."""
        return [
            {"name": t("webrankings_title"), "plugin": "webrankings"}
        ]

    def run(self, config):
        """Render the plugin UI.

        Args:
            config: Configuration passed by the plugin manager; not used here.
        """
        tab1, tab2, tab3 = st.tabs([t("webrankings_title"), t("finetune_bart_title"), t("finetune_embeddings_title")])
        reference_data_valid, reference_data_rejected = self.scansite_plugin.get_reference_data()
        with tab1:
            self._render_analysis_tab(reference_data_valid, reference_data_rejected)
        with tab2:
            self._render_bart_finetune_tab()
        with tab3:
            self._render_embeddings_finetune_tab()
        # Report which device inference runs on.
        device = "GPU (CUDA)" if torch.cuda.is_available() else "CPU"
        st.sidebar.info(t("device_info").format(device))

    def _render_analysis_tab(self, reference_data_valid, reference_data_rejected):
        """Render the analysis inputs and, on request, score and evaluate every reference link."""
        # Rejected entries are 2-tuples; append a note of 0 so that every entry
        # unpacks like the valid ones.
        # NOTE(review): element order is assumed to be (title, link); the first
        # element is what gets matched against the "Titre" result column.
        # Confirm against ScansitePlugin.get_reference_data().
        reference_data = reference_data_valid + [(title, link, 0) for title, link in reference_data_rejected]

        st.title(t("webrankings_title"))
        if st.button(t("clear_memory")):
            self._clear_memory()

        topic = st.text_input(t("enter_topic"), value="longevity, health, healthspan, lifespan")
        use_synonyms = st.checkbox(t("use_keyword_expansion"), value=False)
        check_content = st.checkbox(t("test_content"), value=False)
        selected_llm_models = st.multiselect(t("select_llm_models"), llm_models, default=llm_models)
        zero_shot_names = [m[0] for m in zero_shot_models]
        selected_zero_shot_models = st.multiselect(t("select_zero_shot_models"), zero_shot_names, default=zero_shot_names)
        embedding_names = [m[0] for m in embedding_models]
        selected_embedding_models = st.multiselect(t("select_embedding_models"), embedding_names, default=embedding_names)

        if not st.button(t("analyze_button")):
            return

        with st.spinner(t("loading_models")):
            device = "cuda" if torch.cuda.is_available() else "cpu"
            zero_shot_classifiers, embedding_models_dict = self._load_models(
                selected_zero_shot_models, selected_embedding_models, device
            )
            # NOTE(review): bert_models and tfidf_objects are not used by the
            # analysis below; they are only handed to release_vram(). Loading
            # BERT costs time and VRAM - consider dropping them once
            # release_vram() is confirmed to accept empty lists.
            bert_models = [AutoModel.from_pretrained('bert-base-uncased').to(device)]
            tfidf_objects = [TfidfVectorizer()]
            try:
                expanded_query = self._expand_query(topic, use_synonyms, selected_llm_models)
                start_time = time.time()
                results = []
                for title, link, _note in reference_data:
                    result = analyze_link(
                        title, link, topic, zero_shot_classifiers, embedding_models_dict,
                        expanded_query, selected_llm_models, check_content
                    )
                    if result is not None:
                        results.append(result)
                elapsed = time.time() - start_time
            finally:
                # Release the models even when the analysis raised.
                release_vram(zero_shot_classifiers, embedding_models_dict, bert_models, tfidf_objects)

        self._show_results(results, elapsed, reference_data_valid, reference_data_rejected)

    @staticmethod
    def _clear_memory():
        """Free cached GPU memory and the helper module's global state."""
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            # torch.cuda.synchronize() raises on CPU-only installs, so guard it.
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        clear_globals()
        if cuda_available:
            # reset_cuda_context presumably touches CUDA state only; it is
            # skipped on CPU-only machines - TODO confirm in webrankings_helper.
            reset_cuda_context()

    @staticmethod
    def _load_models(selected_zero_shot_models, selected_embedding_models, device):
        """Instantiate the selected zero-shot pipelines and sentence-embedding models.

        Args:
            selected_zero_shot_models: Display names chosen in the zero-shot multiselect.
            selected_embedding_models: Display names chosen in the embedding multiselect.
            device: ``"cuda"`` or ``"cpu"``.

        Returns:
            A ``(zero_shot_classifiers, embedding_models_dict)`` pair of dicts keyed
            by display name. Local checkpoints (``./...`` paths) that do not exist
            yet are skipped with a warning instead of aborting the analysis.
        """
        import os

        def is_loadable(model):
            # "./bart-large-ft" and "./embeddings-ft" only exist after the
            # corresponding fine-tuning tab has been run.
            if model.startswith("./") and not os.path.exists(model):
                st.warning(f"Model not found, skipped: {model}")
                return False
            return True

        zero_shot_classifiers = {
            name: pipeline("zero-shot-classification", model=model, device=device)
            for name, model in zero_shot_models
            if name in selected_zero_shot_models and is_loadable(model)
        }
        # Only load the embedding models the user selected.
        embedding_models_dict = {
            name: SentenceTransformer(model, trust_remote_code=True).to(device)
            for name, model in embedding_models
            if name in selected_embedding_models and is_loadable(model)
        }
        return zero_shot_classifiers, embedding_models_dict

    @staticmethod
    def _expand_query(topic, use_synonyms, selected_llm_models):
        """Return the query used for scoring, optionally expanded by an LLM.

        When expansion is enabled and at least one LLM is selected, each
        whitespace-separated word of ``topic`` is expanded with the first
        selected LLM and the expansions are joined with spaces. Otherwise
        ``topic`` is returned unchanged.
        """
        if not (use_synonyms and selected_llm_models):
            return topic
        expanded = []
        for word in topic.split():
            expanded.extend(expand_keywords_llm(word, llm_model=selected_llm_models[0]))
        expanded_query = " ".join(expanded)
        st.write(t("expanded_keywords"), expanded_query)
        return expanded_query

    @staticmethod
    def _show_results(results, elapsed, reference_data_valid, reference_data_rejected):
        """Evaluate every score column against the reference labels and display the results.

        Args:
            results: Per-link result rows returned by ``analyze_link``; each row
                is expected to contain a ``"Titre"`` key plus one score per method.
            elapsed: Analysis duration in seconds.
            reference_data_valid: Labelled valid entries; ``item[0]`` identifies the link.
            reference_data_rejected: Labelled rejected entries; ``item[0]`` identifies the link.
        """
        df = pd.DataFrame(results)
        print(f"Analyse terminée en {elapsed:.2f} secondes")
        st.success(t("analysis_completed").format(elapsed))
        if df.empty:
            st.warning("No link could be analyzed.")
            return

        valid_titles = [item[0] for item in reference_data_valid]
        rejected_titles = [item[0] for item in reference_data_rejected]
        score_columns = [column for column in df.columns if column != "Titre"]
        scores_by_title = df.set_index("Titre")

        evaluation_results = {}
        optimal_thresholds = {}
        for column in score_columns:
            method_scores = scores_by_title[column].to_dict()
            threshold = find_optimal_threshold(valid_titles, rejected_titles, method_scores)
            optimal_thresholds[column] = threshold
            evaluation_results[column] = evaluate_ranking(
                valid_titles, rejected_titles, method_scores, threshold, False
            )

        st.write(t("evaluation_results"))
        st.dataframe(pd.DataFrame(evaluation_results).T)
        st.subheader(t("summary_table"))
        st.dataframe(df)
        st.write(t("optimal_thresholds"))
        st.json(optimal_thresholds)

        # Bar chart comparing the Spearman correlation of every method.
        method_names = list(evaluation_results)
        spearman_scores = [evaluation_results[name]['spearman_correlation'] for name in method_names]
        positions = list(range(len(method_names)))
        fig, ax = plt.subplots(figsize=(15, 8))
        ax.bar(positions, spearman_scores)
        ax.set_xticks(positions)
        ax.set_xticklabels(method_names, rotation=90, ha='right')
        ax.set_title(t("spearman_comparison"))
        ax.set_xlabel(t("methods"))
        ax.set_ylabel(t("spearman_correlation"))
        fig.tight_layout()
        st.pyplot(fig)
        # pyplot keeps figures alive until closed, and Streamlit re-executes
        # this code on every interaction, so close it explicitly.
        plt.close(fig)

        # Per-method tables, keeping only links scored above the optimal threshold.
        for column in score_columns:
            st.subheader(t("results_for").format(column))
            df_method = df[["Titre", column]].sort_values(column, ascending=False)
            st.dataframe(df_method[df_method[column] > optimal_thresholds[column]])

    @staticmethod
    def _render_bart_finetune_tab():
        """Render the BART fine-tuning form and launch training when requested."""
        st.title(t("finetune_bart_title"))
        num_epochs = st.number_input("Nombre d'époques", min_value=1, max_value=10, value=2)
        lr = st.number_input("Learning Rate", min_value=1e-6, max_value=1e-1, value=2e-5, format="%.6f", step=1e-5)
        weight_decay = st.number_input("Poids de Décroissance (Weight Decay)", min_value=0.0, max_value=0.1, value=0.01, step=0.005)
        batch_size = st.number_input("Taille du Batch", min_value=1, max_value=16, value=1)
        # NOTE(review): this slider's value is never passed to finetune_bart();
        # wire it up or remove the widget.
        st.slider("Score initial des données valides", min_value=0.0, max_value=1.0, value=0.9, step=0.01)
        model_name = st.text_input("Nom du modèle", value='facebook/bart-large-mnli')
        num_warmup_steps = st.number_input("Nombre d'étapes de Warmup", min_value=0, max_value=100, value=0)
        if st.button("Lancer le fine-tuning"):
            with st.spinner("Fine-tuning en cours..."):
                # The output directory matches the "bart-large-ft" entry of zero_shot_models.
                finetune_bart(num_epochs=num_epochs, lr=lr, weight_decay=weight_decay,
                              batch_size=batch_size, model_name=model_name, output_model='./bart-large-ft',
                              num_warmup_steps=num_warmup_steps)
            st.success("Fine-tuning terminé et modèle sauvegardé.")

    @staticmethod
    def _render_embeddings_finetune_tab():
        """Render the embeddings fine-tuning form and launch training when requested."""
        st.title(t("finetune_embeddings_title"))
        num_epochs_emb = st.number_input("Nombre d'époques (Embeddings)", min_value=1, max_value=100, value=10)
        lr_emb = st.number_input("Learning Rate (Embeddings)", min_value=1e-6, max_value=1e-1, value=2e-5, format="%.6f", step=5e-6)
        weight_decay_emb = st.number_input("Poids de Décroissance (Weight Decay) (Embeddings)", min_value=0.0, max_value=0.1, value=0.01, step=0.005)
        batch_size_emb = st.number_input("Taille du Batch (Embeddings)", min_value=1, max_value=32, value=16)
        # NOTE(review): the initial-score and margin sliders below are rendered
        # but their values are never passed to finetune_embeddings().
        st.slider("Score initial des données valides (Embeddings)", min_value=0.0, max_value=1.0, value=0.9, step=0.01)
        model_name_emb = st.selectbox("Modèle d'embeddings de base", ["nomic-ai/nomic-embed-text-v1", "all-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2"])
        st.slider("Marge (Embeddings)", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
        if st.button("Lancer le fine-tuning des embeddings"):
            with st.spinner("Fine-tuning des embeddings en cours..."):
                # The output directory matches the "embeddings-ft" entry of embedding_models.
                finetune_embeddings(model_name=model_name_emb, output_model_name="./embeddings-ft",
                                    num_epochs=num_epochs_emb,
                                    learning_rate=lr_emb,
                                    weight_decay=weight_decay_emb,
                                    batch_size=batch_size_emb,
                                    )
            st.success("Fine-tuning des embeddings terminé et modèle sauvegardé.")