johannoriel committed
Commit f34a6fd • 1 Parent(s): 864cca3
Initial release. Tested. Working.

Files changed:
- app.py +203 -0
- bart_ft.py +89 -0
- embeddings_ft.py +95 -0
- plugins/__pycache__/common.cpython-310.pyc +0 -0
- plugins/__pycache__/ragllm.cpython-310.pyc +0 -0
- plugins/__pycache__/scansite.cpython-310.pyc +0 -0
- plugins/__pycache__/webrankings.cpython-310.pyc +0 -0
- plugins/common.py +44 -0
- plugins/ragllm.py +306 -0
- plugins/scansite.py +503 -0
- plugins/webrankings.py +285 -0
- webrankings_helper.py +410 -0
app.py
ADDED
@@ -0,0 +1,203 @@
import os
import json
import importlib
import streamlit as st
from typing import List, Dict, Any, Set
from dotenv import load_dotenv
from global_vars import translations, t

# Constants
CONFIG_FILE = "config.json"

def load_config() -> Dict[str, Any]:
    if os.path.exists(CONFIG_FILE):
        with open(CONFIG_FILE, 'r') as f:
            return json.load(f)
    return {}

def save_config(config: Dict[str, Any]):
    with open(CONFIG_FILE, 'w') as f:
        json.dump(config, f, indent=2)

# Update the current language
def set_lang(language):
    st.session_state.lang = language

# Translation helper (shadows the t imported from global_vars)
def t(key: str) -> str:
    return translations[st.session_state.lang].get(key, key)

class Plugin:
    def __init__(self, name, plugin_manager):
        self.name = name
        self.plugin_manager = plugin_manager

    def get_config_fields(self) -> Dict[str, Any]:
        return {}

    def get_config_ui(self, config):
        updated_config = {}
        for field, params in self.get_config_fields().items():
            if params['type'] == 'select':
                updated_config[field] = st.selectbox(
                    params['label'],
                    options=[option[0] for option in params['options']],
                    format_func=lambda x: dict(params['options'])[x],
                    index=[option[0] for option in params['options']].index(config.get(field, params['default']))
                )
            elif params['type'] == 'textarea':
                updated_config[field] = st.text_area(
                    params['label'],
                    value=config.get(field, params['default'])
                )
            else:
                updated_config[field] = st.text_input(
                    params['label'],
                    value=config.get(field, params['default']),
                    type="password" if field.startswith("pass") else "default"
                )
        return updated_config

    def get_tabs(self) -> List[Dict[str, Any]]:
        return []

    def run(self, config: Dict[str, Any]):
        pass

    def get_sidebar_config_ui(self, config: Dict[str, Any]) -> Dict[str, Any]:
        return {}

class PluginManager:
    def __init__(self):
        self.plugins: Dict[str, Plugin] = {}
        self.starred_plugins: Set[str] = set()

    def load_plugins(self):
        plugins_dir = 'plugins'
        for filename in os.listdir(plugins_dir):
            if filename.endswith('.py'):
                module_name = filename[:-3]
                module = importlib.import_module(f'plugins.{module_name}')
                plugin_class = getattr(module, f'{module_name.capitalize()}Plugin')
                self.plugins[module_name] = plugin_class(module_name, self)

    def get_plugin(self, plugin_name: str) -> Plugin:
        return self.plugins.get(plugin_name)

    def get_all_config_ui(self, config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        all_ui = {}
        for plugin_name, plugin in sorted(self.plugins.items()):
            with st.expander(f"{'⭐ ' if plugin_name in self.starred_plugins else ''}{t('configurations')} {plugin_name}"):
                all_ui[plugin_name] = plugin.get_config_ui(config.get(plugin_name, {}))
                if st.button(f"{'Unstar' if plugin_name in self.starred_plugins else 'Star'} {plugin_name}"):
                    if plugin_name in self.starred_plugins:
                        self.starred_plugins.remove(plugin_name)
                    else:
                        self.starred_plugins.add(plugin_name)
                    self.save_starred_plugins(config)
                    st.rerun()
        return all_ui

    def get_all_tabs(self) -> List[Dict[str, Any]]:
        all_tabs = []
        for plugin_name, plugin in sorted(self.plugins.items()):
            tabs = plugin.get_tabs()
            for tab in tabs:
                tab['id'] = plugin_name
                tab['starred'] = plugin_name in self.starred_plugins
            all_tabs.extend(tabs)
        return all_tabs

    def load_starred_plugins(self, config: Dict[str, Any]):
        self.starred_plugins = set(config.get('starred_plugins', []))

    def save_starred_plugins(self, config: Dict[str, Any]):
        config['starred_plugins'] = list(self.starred_plugins)
        save_config(config)

    def run_plugin(self, plugin_name: str, config: Dict[str, Any]):
        if plugin_name in self.plugins:
            self.plugins[plugin_name].run(config)

    def save_config(self, config):
        save_config(config)

def main():
    st.set_page_config(page_title="Veille", layout="wide")
    # Initialize the plugin manager
    plugin_manager = PluginManager()
    plugin_manager.load_plugins()

    # Load the configuration
    config = load_config()
    plugin_manager.load_starred_plugins(config)

    # Initialize the language in st.session_state
    if 'lang' not in st.session_state:
        st.session_state.lang = config['common']['language']
    st.title(t("page_title"))

    load_dotenv()
    LLM_KEY = os.getenv("LLM_API_KEY")
    config['llm_key'] = LLM_KEY

    # Build the tab list with unique identifiers
    tabs = [{"id": "configurations", "name": t("configurations")}] + [{"id": tab['plugin'], "name": tab['name'], "starred": tab['starred']} for tab in plugin_manager.get_all_tabs()]

    # Language handling
    if 'lang' not in st.session_state:
        st.session_state.lang = "fr"

    new_lang = st.sidebar.selectbox("Choose your language / Choisissez votre langue", options=["en", "fr"], index=["en", "fr"].index(st.session_state.lang), key="lang_selector")

    if new_lang != st.session_state.lang:
        st.session_state.lang = new_lang
        st.rerun()

    # Add each plugin's sidebar configuration widgets
    for plugin_name, plugin in plugin_manager.plugins.items():
        sidebar_config = plugin.get_sidebar_config_ui(config.get(plugin_name, {}))
        if sidebar_config:
            #st.sidebar.markdown(f"**{plugin_name} Configuration**")
            for key, value in sidebar_config.items():
                config.setdefault(plugin_name, {})[key] = value

    # Track the selected tab
    if 'selected_tab_id' not in st.session_state:
        st.session_state.selected_tab_id = "configurations"

    # Sort tabs alphabetically, with starred tabs first
    sorted_tabs = sorted(tabs, key=lambda x: (not x.get('starred', False), x['name']))
    tab_names = [f"{'⭐ ' if tab.get('starred', False) else ''}{tab['name']}" for tab in sorted_tabs]

    selected_tab_index = [tab["id"] for tab in sorted_tabs].index(st.session_state.selected_tab_id)
    selected_tab = st.sidebar.radio(t("navigation"), tab_names, index=selected_tab_index, key="tab_selector")

    new_selected_tab_id = next(tab["id"] for tab in sorted_tabs if f"{'⭐ ' if tab.get('starred', False) else ''}{tab['name']}" == selected_tab)

    if new_selected_tab_id != st.session_state.selected_tab_id:
        st.session_state.selected_tab_id = new_selected_tab_id
        st.rerun()

    if st.session_state.selected_tab_id == "configurations":
        st.header(t("configurations"))
        all_config_ui = plugin_manager.get_all_config_ui(config)

        for plugin_name, ui_config in all_config_ui.items():
            with st.expander(f"{t('configurations')} {plugin_name}"):
                config[plugin_name] = ui_config

        if st.button(t("save_button")):
            save_config(config)
            st.success(t("success_message"))

    else:
        # Run the plugin matching the selected tab
        for tab in plugin_manager.get_all_tabs():
            if tab['plugin'] == st.session_state.selected_tab_id:
                plugin_manager.run_plugin(tab['plugin'], config)
                break

if __name__ == "__main__":
    main()
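
app.py resolves each plugin class by naming convention: a module plugins/<name>.py must expose a class called <Name>Plugin (the module name capitalized, plus the Plugin suffix). A minimal conforming plugin, with hypothetical names, could look like this:

# plugins/hello.py -- hypothetical example, not part of this commit.
# load_plugins() imports plugins.hello and looks up getattr(module, 'HelloPlugin').
from app import Plugin
import streamlit as st

class HelloPlugin(Plugin):
    def get_tabs(self):
        # One sidebar tab; 'plugin' must match the module name.
        return [{"name": "Hello", "plugin": "hello"}]

    def run(self, config):
        st.write("Hello from a minimal plugin")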
bart_ft.py
ADDED
@@ -0,0 +1,89 @@
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BartForSequenceClassification, BartTokenizer, get_linear_schedule_with_warmup
from transformers import AdamW
from tqdm import tqdm
import gc
from plugins.scansite import ScansitePlugin  # Make sure this import path is correct

torch.cuda.empty_cache()

# Custom dataset definition
class PreferenceDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        _, title, score = self.data[idx]
        encoding = self.tokenizer(title, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        # Convert the score to a binary label (0 or 1)
        label = 1 if score > 0 else 0
        return {key: val.squeeze(0) for key, val in encoding.items()}, torch.tensor(label, dtype=torch.long)

# Main fine-tuning routine
def finetune(num_epochs=2, lr=2e-5, weight_decay=0.01, batch_size=1, model_name='facebook/bart-large-mnli', output_model='./bart-large-ft', num_warmup_steps=0):
    print(f"Fine-tuning parameters:\n"
          f"num_epochs: {num_epochs}\n"
          f"learning rate (lr): {lr}\n"
          f"weight_decay: {weight_decay}\n"
          f"batch_size: {batch_size}\n"
          f"model_name: {model_name}\n"
          f"num_warmup_steps: {num_warmup_steps}")

    # Fetch the reference data
    scansite_plugin = ScansitePlugin("scansite", None)  # You may need to adjust this to match your structure
    reference_data_valid, reference_data_rejected = scansite_plugin.get_reference_data()

    # Combine the valid and rejected data
    all_data = reference_data_valid + [(url, title, 0) for url, title in reference_data_rejected]

    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForSequenceClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)

    dataset = PreferenceDataset(all_data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.gradient_checkpointing_enable()

    for epoch in range(num_epochs):
        model.train()
        for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            optimizer.zero_grad()

            inputs, labels = batch
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)

            outputs = model(**inputs, labels=labels)
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            del inputs, outputs, labels
            torch.cuda.empty_cache()
            gc.collect()
        #print(f"Fine-tuning in progress, round {epoch}/{num_epochs}")
    print("Fine-tuning finished, saving in progress.")
    model.save_pretrained(output_model)
    tokenizer.save_pretrained(output_model)

    print("Fine-tuning finished and model saved.")

# Default call when the script is run directly
if __name__ == "__main__":
    finetune()
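
The defaults above (batch_size=1, gradient checkpointing, per-batch cache clearing) favor small GPUs over speed; finetune() also accepts overrides, so a lighter smoke-test run could look like the following (hypothetical values, assuming the scansite database already holds rated articles):

# Hypothetical invocation of the entry point above.
from bart_ft import finetune

finetune(num_epochs=1, batch_size=2, output_model='./bart-large-ft-test')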
embeddings_ft.py
ADDED
@@ -0,0 +1,95 @@
import torch
from torch.utils.data import Dataset, DataLoader
from sentence_transformers import SentenceTransformer, losses
from tqdm import tqdm
import gc
from plugins.scansite import ScansitePlugin

torch.cuda.empty_cache()

class PreferenceDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        url, title, score = self.data[idx]
        encoded = self.tokenizer(title, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt")
        return {key: val.squeeze(0) for key, val in encoded.items()}, torch.tensor(score, dtype=torch.float)

def collate_fn(batch):
    input_ids = torch.stack([item[0]['input_ids'] for item in batch])
    attention_mask = torch.stack([item[0]['attention_mask'] for item in batch])
    scores = torch.stack([item[1] for item in batch])
    return {'input_ids': input_ids, 'attention_mask': attention_mask}, scores

def finetune(model_name='nomic-ai/nomic-embed-text-v1', output_model_name="embeddings-ft", num_epochs=2, learning_rate=2e-5, weight_decay=0.01, batch_size=8, num_warmup_steps=0):
    print(f"Fine-tuning parameters:\n"
          f"num_epochs: {num_epochs}\n"
          f"learning rate (lr): {learning_rate}\n"
          f"weight_decay: {weight_decay}\n"
          f"batch_size: {batch_size}\n"
          f"model_name: {model_name}\n"
          f"num_warmup_steps: {num_warmup_steps}")

    scansite_plugin = ScansitePlugin("scansite", None)
    reference_data_valid, reference_data_rejected = scansite_plugin.get_reference_data()

    valid_data_with_scores = [(url, title, (score - 1) / 8 + 0.5) for url, title, score in reference_data_valid]
    rejected_data_with_scores = [(url, title, 0.0) for url, title in reference_data_rejected]

    all_data = valid_data_with_scores + rejected_data_with_scores

    model = SentenceTransformer(model_name, trust_remote_code=True)
    tokenizer = model.tokenizer

    dataset = PreferenceDataset(all_data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    loss_function = torch.nn.MSELoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    total_steps = len(dataloader) * num_epochs
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            input_data, scores = batch
            input_data = {k: v.to(device) for k, v in input_data.items()}
            scores = scores.to(device)

            optimizer.zero_grad()

            embeddings = model(input_data)['sentence_embedding']

            # Cosine-similarity step: each L2-normalized embedding is
            # collapsed to a scalar by summing its components
            embeddings_norm = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            cosine_similarities = torch.sum(embeddings_norm, dim=1)

            # Loss computation
            loss = loss_function(cosine_similarities, scores)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            del embeddings, cosine_similarities
            torch.cuda.empty_cache()
            gc.collect()

    model.save(output_model_name)

    print("Fine-tuning finished and model saved.")

if __name__ == "__main__":
    finetune()
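
The affine mapping (score - 1) / 8 + 0.5 turns 1-5 star ratings into regression targets in [0.5, 1.0], while rejected articles are pinned to 0.0, leaving a deliberate gap between "rejected" and "barely accepted". A quick check of the endpoints:

# Quick check of the rating-to-target mapping used above.
for score in (1, 3, 5):
    print(score, (score - 1) / 8 + 0.5)
# 1 -> 0.5, 3 -> 0.75, 5 -> 1.0; rejected items get 0.0 directly.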
plugins/__pycache__/common.cpython-310.pyc
ADDED
Binary file (1.6 kB)
plugins/__pycache__/ragllm.cpython-310.pyc
ADDED
Binary file (11.7 kB)
plugins/__pycache__/scansite.cpython-310.pyc
ADDED
Binary file (18.5 kB)
plugins/__pycache__/webrankings.cpython-310.pyc
ADDED
Binary file (10.4 kB)
plugins/common.py
ADDED
@@ -0,0 +1,44 @@
from global_vars import t, translations
from app import Plugin
import streamlit as st
import torch

# Plugin-specific translations
translations["en"].update({
    "work_directory": "Work Directory",
    "page_title": "News watcher",
})
translations["fr"].update({
    "work_directory": "Répertoire de travail",
    "page_title": "Outil de veille",
})

class CommonPlugin(Plugin):
    def get_config_fields(self):
        return {
            "work_directory": {
                "type": "text",
                "label": t("work_directory"),
                "default": "/home/joriel/Vidéos"
            },
            "language": {
                "type": "select",
                "label": t("preferred_language"),
                "options": [("fr", "Français"), ("en", "Anglais")],
                "default": "fr"
            },
        }

    def get_tabs(self):
        return [{"name": "Commun", "plugin": "common"}]

    def run(self, config):
        st.header("Common Plugin")
        st.write(f"{t('work_directory')}: {config['common']['work_directory']}")
        torch.cuda.empty_cache()
        st.write("CUDA memory reset")

def remove_quotes(s):
    if s.startswith('"') and s.endswith('"'):
        return s[1:-1]
    return s
plugins/ragllm.py
ADDED
@@ -0,0 +1,306 @@
from global_vars import translations, t
from app import Plugin
import streamlit as st
import yaml
from litellm import completion, embedding
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, manhattan_distances
import os
from typing import List, Dict, Any
import requests
import torch
from transformers import AutoTokenizer, AutoModel

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MAX_LENGTH = 512
CHUNK_SIZE = 200  # Number of words per chunk

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Plugin-specific translations
translations["en"].update({
    "rag_plugin_loaded": "RAG LLM Plugin loaded",
    "rag_enter_text": "Enter RAG text:",
    "rag_enter_question": "Enter your question:",
    "rag_button_get_answer": "Get an answer",
    "rag_success_text_processed": "RAG text processed successfully!",
    "rag_warning_enter_text": "Please enter RAG text.",
    "rag_warning_process_text_first": "Please process the RAG text first.",
    "rag_warning_enter_question": "Please enter a question.",
    "rag_answer": "Answer:",
    "rag_citations": "Citations:",
    "rag_model_provider": "Model Provider",
    "rag_llm_model": "LLM Model",
    "rag_embedder_model": "Embedding Model",
    "rag_similarity_method": "Similarity Method",
    "rag_llm_sys_prompt": "System prompt for LLM",
    "rag_chunk_size": "Chunk size",
    "rag_top_k_chunks": "Number of chunks to use",
    "rag_default_sys_prompt": "You are an AI assistant. Your task is to analyze the provided context and answer questions based ONLY on this context. If the information is not in the context, clearly state that.",
    "rag_error_fetching_models_ollama": "Error fetching Ollama models: ",
    "rag_error_calling_llm": "Error calling LLM: ",
    "rag_processing": "Processing...",
})

translations["fr"].update({
    "rag_plugin_loaded": "Plugin RAG LLM chargé",
    "rag_enter_text": "Entrez le texte RAG :",
    "rag_enter_question": "Entrez votre question :",
    "rag_button_get_answer": "Obtenir une réponse",
    "rag_success_text_processed": "Texte RAG traité avec succès!",
    "rag_warning_enter_text": "Veuillez entrer du texte RAG.",
    "rag_warning_process_text_first": "Veuillez d'abord traiter le texte RAG.",
    "rag_warning_enter_question": "Veuillez entrer une question.",
    "rag_answer": "Réponse :",
    "rag_citations": "Citations :",
    "rag_model_provider": "Fournisseur de modèle",
    "rag_llm_model": "Modèle LLM",
    "rag_embedder_model": "Modèle d'embedding",
    "rag_similarity_method": "Méthode de similarité",
    "rag_llm_sys_prompt": "Prompt système pour le LLM",
    "rag_chunk_size": "Taille des chunks",
    "rag_top_k_chunks": "Nombre de chunks à utiliser",
    "rag_default_sys_prompt": "Tu es un assistant IA. Ta tâche est d'analyser le contexte fourni et de répondre aux questions en te basant UNIQUEMENT sur ce contexte. Si l'information n'est pas dans le contexte, dis-le clairement.",
    "rag_error_fetching_models_ollama": "Erreur lors de la récupération des modèles Ollama : ",
    "rag_error_calling_llm": "Erreur lors de l'appel au LLM : ",
    "rag_processing": "En cours de traitement...",
})

class RagllmPlugin(Plugin):
    def __init__(self, name: str, plugin_manager):
        super().__init__(name, plugin_manager)
        self.config = self.load_llm_config()
        self.embeddings = None
        self.chunks = None

    def load_llm_config(self) -> Dict:
        with open('.llm-config.yml', 'r') as file:
            return yaml.safe_load(file)

    def get_tabs(self):
        return [{"name": "RAG", "plugin": "ragllm"}]

    def get_config_fields(self):
        return {
            "provider": {
                "type": "select",
                "label": t("rag_model_provider"),
                "options": [("ollama", "Ollama"), ("groq", "Groq")],
                "default": "ollama"
            },
            "llm_model": {
                "type": "select",
                "label": t("rag_llm_model"),
                "options": [("none", "À charger...")],
                "default": "ollama/qwen2"
            },
            "embedder": {
                "type": "select",
                "label": t("rag_embedder_model"),
                "options": [
                    ("sentence-transformers/all-MiniLM-L6-v2", "all-MiniLM-L6-v2"),
                    ("nomic-ai/nomic-embed-text-v1.5", "nomic-embed-text-v1.5")
                ],
                "default": "sentence-transformers/all-MiniLM-L6-v2"
            },
            "similarity_method": {
                "type": "select",
                "label": t("rag_similarity_method"),
                "options": [
                    ("cosine", "Cosinus"),
                    ("euclidean", "Distance euclidienne"),
                    ("manhattan", "Distance de Manhattan")
                ],
                "default": "cosine"
            },
            "llm_sys_prompt": {
                "type": "textarea",
                "label": t("rag_llm_sys_prompt"),
                "default": t("rag_default_sys_prompt")
            },
            "chunk_size": {
                "type": "number",
                "label": t("rag_chunk_size"),
                "default": 200
            },
            "top_k": {
                "type": "number",
                "label": t("rag_top_k_chunks"),
                "default": 3
            }
        }

    def get_config_ui(self, config):
        updated_config = {}
        for field, params in self.get_config_fields().items():
            if params['type'] == 'select':
                if field == 'llm_model':
                    provider = config.get('provider', 'ollama')
                    models = self.get_available_models(provider)
                    try:
                        default_index = models.index(config.get(field, params['default']))
                    except ValueError:
                        default_index = 0
                    updated_config[field] = st.selectbox(
                        params['label'],
                        options=models,
                        index=default_index
                    )
                else:
                    options_list = [option[0] for option in params['options']]
                    try:
                        default_index = options_list.index(config.get(field, params['default']))
                    except ValueError:
                        default_index = 0
                    updated_config[field] = st.selectbox(
                        params['label'],
                        options=options_list,
                        format_func=lambda x: dict(params['options'])[x],
                        index=default_index
                    )
            elif params['type'] == 'textarea':
                updated_config[field] = st.text_area(
                    params['label'],
                    value=config.get(field, params['default'])
                )
            elif params['type'] == 'number':
                updated_config[field] = st.number_input(
                    params['label'],
                    value=int(config.get(field, params['default'])),
                    step=1
                )
            else:
                updated_config[field] = st.text_input(
                    params['label'],
                    value=config.get(field, params['default'])
                )
        return updated_config

    def get_sidebar_config_ui(self, config: Dict[str, Any]) -> Dict[str, Any]:
        available_models = self.get_available_models('ollama') + self.get_available_models('groq')
        default_model = config.get('llm_model', available_models[0] if available_models else None)
        selected_model = st.sidebar.selectbox(
            t("rag_llm_model"),
            options=available_models,
            index=available_models.index(default_model) if default_model in available_models else 0,
            key="ragllm_llm_model"
        )
        return {"llm_model": selected_model}

    def get_available_models(self, provider: str) -> List[str]:
        if provider == 'ollama':
            try:
                response = requests.get("http://localhost:11434/api/tags")
                models = response.json()["models"]
                return [f"ollama/{model['name']}" for model in models] + ["ollama/qwen2"]
            except Exception as e:
                st.error(f"{t('rag_error_fetching_models_ollama')}{str(e)}")
                return ["ollama/qwen2"]
        elif provider == 'groq':
            return ["groq/llama3-70b-8192", "groq/mixtral-8x7b-32768"]
        else:
            return ["none"]

    def process_rag_text(self, rag_text: str, chunk_size: int, embedder):
        rag_text = rag_text.replace('\\n', ' ').replace('\\\'', "'")
        mots = rag_text.split()
        self.chunks = [' '.join(mots[i:i+chunk_size]) for i in range(0, len(mots), chunk_size)]
        self.embeddings = np.vstack([self.get_embedding(c, embedder) for c in self.chunks])

    def get_embedding(self, text: str, model: str) -> np.ndarray:
        tokenizer = AutoTokenizer.from_pretrained(model)
        model = AutoModel.from_pretrained(model, trust_remote_code=True).to(DEVICE)
        inputs = tokenizer(text, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            model_output = model(**inputs)
        return mean_pooling(model_output, inputs['attention_mask']).cpu().numpy()

    def calculate_similarity(self, query_embedding: np.ndarray, method: str) -> np.ndarray:
        if method == 'cosine':
            return cosine_similarity(query_embedding.reshape(1, -1), self.embeddings)[0]
        elif method == 'euclidean':
            return -euclidean_distances(query_embedding.reshape(1, -1), self.embeddings)[0]
        elif method == 'manhattan':
            return -manhattan_distances(query_embedding.reshape(1, -1), self.embeddings)[0]
        else:
            raise ValueError("Unrecognized similarity method")

    def get_context(self, query: str, config: Dict[str, Any]) -> tuple:
        query_embedding = self.get_embedding(query, config['ragllm']['embedder'])
        similarities = self.calculate_similarity(query_embedding, config['ragllm']['similarity_method'])
        top_indices = np.argsort(similarities)[-config['ragllm']['top_k']:][::-1]
        context = "\n\n".join([self.chunks[i] for i in top_indices])
        return context, [self.chunks[i] for i in top_indices]

    def call_llm(self, prompt: str, sysprompt: str) -> str:
        try:
            llm_model = st.session_state.ragllm_llm_model
            #print(f"---------------------------------------\nCalling LLM {llm_model} \n with sysprompt {sysprompt} \n and prompt {prompt} \n and context len of {len(context)}")
            messages = [
                {"role": "system", "content": sysprompt},
                {"role": "user", "content": prompt}
            ]
            response = completion(model=llm_model, messages=messages)
            return response['choices'][0]['message']['content']
        except Exception as e:
            return f"{t('rag_error_calling_llm')}{str(e)}"

    def free_llm(self):
        try:
            llm_model = st.session_state.ragllm_llm_model
            if llm_model.startswith("ollama/"):
                ollama_model = llm_model.split("/")[1]
                response = requests.post(
                    "http://localhost:11434/api/generate",
                    json={
                        "model": ollama_model,
                        "prompt": "bye",
                        "keep_alive": 0
                    }
                )
                return response.json()['response']
        except Exception as e:
            return f"{t('rag_error_calling_llm')}{str(e)}"

    def process_with_llm(self, prompt: str, sysprompt: str, context: str) -> str:
        return self.call_llm(f"Contexte : {context}\n\nQuestion : {prompt}", sysprompt)

    def run(self, config):
        st.write(t("rag_plugin_loaded"))

        # Initialize rag_text from session_state if present, otherwise use an empty string
        if 'rag_text' not in st.session_state:
            st.session_state.rag_text = ""
        if 'rag_question' not in st.session_state:
            st.session_state.rag_question = "Question"

        rag_text = st.text_area(t("rag_enter_text"), height=200, value=st.session_state.rag_text, key="rag_text_key")
        user_prompt = st.text_area(t("rag_enter_question"), value=st.session_state.rag_question, key="rag_prompt_key")
        st.session_state.rag_text = rag_text  # Update the value in session_state
        st.session_state.rag_question = user_prompt

        if st.button(t("rag_button_get_answer"), key="get_answer_button"):

            with st.spinner(t("rag_processing")):
                if rag_text:
                    self.process_rag_text(rag_text, config['ragllm']['chunk_size'], config['ragllm']['embedder'])
                    st.success(t("rag_success_text_processed"))
                else:
                    st.warning(t("rag_warning_enter_text"))
                if user_prompt and self.embeddings is not None:
                    context, citations = self.get_context(user_prompt, config)
                    response = self.process_with_llm(user_prompt, config['ragllm']['llm_sys_prompt'], context)

                    st.write(t("rag_answer"))
                    st.write(response)

                    st.write(t("rag_citations"))
                    for i, citation in enumerate(citations, 1):
                        st.write(f"{i}. {citation[:100]}...")
                elif self.embeddings is None:
                    st.warning(t("rag_warning_process_text_first"))
                else:
                    st.warning(t("rag_warning_enter_question"))
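
mean_pooling above averages token embeddings while masking out padding positions via the attention mask. A small self-contained check of the arithmetic with dummy tensors (not the plugin's real models):

import torch

# One sequence of three 2-dim token embeddings; the third token is padding.
token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(pooled)  # tensor([[2., 3.]]) -- the padded token is excluded from the mean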
plugins/scansite.py
ADDED
@@ -0,0 +1,503 @@
1 |
+
from app import Plugin
|
2 |
+
import streamlit as st
|
3 |
+
import sqlite3
|
4 |
+
import requests
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
from datetime import datetime
|
7 |
+
import ollama
|
8 |
+
from global_vars import t, translations
|
9 |
+
|
10 |
+
# Ajout des traductions spécifiques à ce plugin
|
11 |
+
translations["en"].update({
|
12 |
+
"scansite_title": "News Aggregator",
|
13 |
+
"total_links": "Total number of links",
|
14 |
+
"annotated_links": "Number of annotated links",
|
15 |
+
"known_tags": "Known tags",
|
16 |
+
"reset_database": "Reset database",
|
17 |
+
"database_reset_success": "Database reset successfully",
|
18 |
+
"launch_scan": "Launch scan",
|
19 |
+
"scan_complete": "Scan complete",
|
20 |
+
"no_articles": "No articles to display.",
|
21 |
+
"page": "Page",
|
22 |
+
"previous_page": "Previous page",
|
23 |
+
"next_page": "Next page",
|
24 |
+
"new_articles": "New Articles",
|
25 |
+
"rated_articles": "Rated Articles",
|
26 |
+
"clicked_not_rated": "Clicked but not rated Articles",
|
27 |
+
"tagged_articles": "Tagged Articles",
|
28 |
+
"ignored_articles": "Ignored Articles",
|
29 |
+
"excluded_articles": "Excluded Articles",
|
30 |
+
"rating": "Rating",
|
31 |
+
"tags": "Tags",
|
32 |
+
"exclude": "Exclude",
|
33 |
+
"sources": "Sources",
|
34 |
+
"update": "Update",
|
35 |
+
"delete": "Delete",
|
36 |
+
"add_new_source": "Add a new source (URL)",
|
37 |
+
"add_source": "Add source",
|
38 |
+
"new_tag": "New tag",
|
39 |
+
"new_tag_description": "New tag description",
|
40 |
+
"add_tag": "Add tag",
|
41 |
+
"work_directory": "Work Directory",
|
42 |
+
})
|
43 |
+
|
44 |
+
translations["fr"].update({
|
45 |
+
"scansite_title": "Agrégateur de Nouvelles",
|
46 |
+
"total_links": "Nombre total de liens",
|
47 |
+
"annotated_links": "Nombre de liens annotés",
|
48 |
+
"known_tags": "Tags connus",
|
49 |
+
"reset_database": "Réinitialiser la base de données",
|
50 |
+
"database_reset_success": "Base de données réinitialisée",
|
51 |
+
"launch_scan": "Lancer le scan",
|
52 |
+
"scan_complete": "Scan terminé",
|
53 |
+
"no_articles": "Aucun article à afficher.",
|
54 |
+
"page": "Page",
|
55 |
+
"previous_page": "Page précédente",
|
56 |
+
"next_page": "Page suivante",
|
57 |
+
"new_articles": "Nouveaux Articles",
|
58 |
+
"rated_articles": "Articles Notés",
|
59 |
+
"clicked_not_rated": "Articles Cliqués non notés",
|
60 |
+
"tagged_articles": "Articles Tagués",
|
61 |
+
"ignored_articles": "Articles Ignorés",
|
62 |
+
"excluded_articles": "Articles Exclus",
|
63 |
+
"rating": "Note",
|
64 |
+
"tags": "Tags",
|
65 |
+
"exclude": "Exclure",
|
66 |
+
"sources": "Sources",
|
67 |
+
"update": "Mettre à jour",
|
68 |
+
"delete": "Supprimer",
|
69 |
+
"add_new_source": "Ajouter une nouvelle source (URL)",
|
70 |
+
"add_source": "Ajouter source",
|
71 |
+
"new_tag": "Nouveau tag",
|
72 |
+
"new_tag_description": "Description du nouveau tag",
|
73 |
+
"add_tag": "Ajouter tag",
|
74 |
+
"work_directory": "Répertoire de travail",
|
75 |
+
})
|
76 |
+
|
77 |
+
class ScansitePlugin(Plugin):
|
78 |
+
def __init__(self, name, plugin_manager):
|
79 |
+
super().__init__(name, plugin_manager)
|
80 |
+
self.conn = self.get_connection()
|
81 |
+
self.c = self.conn.cursor()
|
82 |
+
self.init_db()
|
83 |
+
|
84 |
+
def get_connection(self):
|
85 |
+
return sqlite3.connect('news_app.db', check_same_thread=False)
|
86 |
+
|
87 |
+
def init_db(self):
|
88 |
+
current_version = self.get_db_version()
|
89 |
+
if current_version < 1:
|
90 |
+
self.c.execute('''CREATE TABLE IF NOT EXISTS sources
|
91 |
+
(id INTEGER PRIMARY KEY, url TEXT, title TEXT)''')
|
92 |
+
self.c.execute('''CREATE TABLE IF NOT EXISTS articles
|
93 |
+
(id INTEGER PRIMARY KEY, source_id INTEGER, url TEXT UNIQUE, title TEXT, date TEXT,
|
94 |
+
is_new INTEGER, is_excluded INTEGER DEFAULT 0)''')
|
95 |
+
self.c.execute('''CREATE TABLE IF NOT EXISTS user_actions
|
96 |
+
(id INTEGER PRIMARY KEY, article_id INTEGER, action TEXT, rating INTEGER, tags TEXT, timestamp TEXT)''')
|
97 |
+
self.c.execute('''CREATE TABLE IF NOT EXISTS tags
|
98 |
+
(id INTEGER PRIMARY KEY, name TEXT UNIQUE, description TEXT)''')
|
99 |
+
self.set_db_version(1)
|
100 |
+
|
101 |
+
# Add more version upgrades here
|
102 |
+
# if current_version < 2:
|
103 |
+
# self.c.execute('''ALTER TABLE articles ADD COLUMN new_column TEXT''')
|
104 |
+
# self.set_db_version(2)
|
105 |
+
|
106 |
+
self.conn.commit()
|
107 |
+
|
108 |
+
def get_db_version(self):
|
109 |
+
self.c.execute('''CREATE TABLE IF NOT EXISTS db_version (version INTEGER)''')
|
110 |
+
self.c.execute('SELECT version FROM db_version')
|
111 |
+
result = self.c.fetchone()
|
112 |
+
return result[0] if result else 0
|
113 |
+
|
114 |
+
def set_db_version(self, version):
|
115 |
+
self.c.execute('INSERT OR REPLACE INTO db_version (rowid, version) VALUES (1, ?)', (version,))
|
116 |
+
self.conn.commit()
|
117 |
+
|
118 |
+
def get_tabs(self):
|
119 |
+
return [{"name": t("scansite_title"), "plugin": "scansite"}]
|
120 |
+
|
121 |
+
def run(self, config):
|
122 |
+
st.title(t("scansite_title"))
|
123 |
+
|
124 |
+
total_links, annotated_links = self.get_stats()
|
125 |
+
st.write(f"{t('total_links')} : {total_links}")
|
126 |
+
st.write(f"{t('annotated_links')} : {annotated_links}")
|
127 |
+
|
128 |
+
all_tags = self.get_all_tags()
|
129 |
+
st.write(f"{t('known_tags')} :", ", ".join(all_tags))
|
130 |
+
|
131 |
+
if st.button(t("reset_database")):
|
132 |
+
self.reset_database()
|
133 |
+
st.success(t("database_reset_success"))
|
134 |
+
|
135 |
+
if st.button(t("launch_scan")):
|
136 |
+
self.launch_scan()
|
137 |
+
st.success(t("scan_complete"))
|
138 |
+
|
139 |
+
self.display_tabs()
|
140 |
+
|
141 |
+
def get_stats(self):
|
142 |
+
total_links = self.c.execute("SELECT COUNT(*) FROM articles WHERE is_excluded = 0").fetchone()[0]
|
143 |
+
annotated_links = self.c.execute("""
|
144 |
+
SELECT COUNT(DISTINCT article_id) FROM user_actions
|
145 |
+
WHERE action IN ('click', 'rate', 'tag')
|
146 |
+
""").fetchone()[0]
|
147 |
+
return total_links, annotated_links
|
148 |
+
|
149 |
+
def get_all_tags(self):
|
150 |
+
return [row[0] for row in self.c.execute("SELECT name FROM tags").fetchall()]
|
151 |
+
|
152 |
+
def reset_database(self):
|
153 |
+
self.c.execute("DROP TABLE IF EXISTS sources")
|
154 |
+
self.c.execute("DROP TABLE IF EXISTS articles")
|
155 |
+
self.c.execute("DROP TABLE IF EXISTS user_actions")
|
156 |
+
self.c.execute("DROP TABLE IF EXISTS tags")
|
157 |
+
self.conn.commit()
|
158 |
+
self.init_db()
|
159 |
+
|
160 |
+
def launch_scan(self):
|
161 |
+
sources = self.c.execute("SELECT * FROM sources").fetchall()
|
162 |
+
for source in sources:
|
163 |
+
self.mark_not_new(source[0])
|
164 |
+
links = self.scan_new_links(source[0], source[1])
|
165 |
+
for link, title in links:
|
166 |
+
self.c.execute("""
|
167 |
+
INSERT OR IGNORE INTO articles (source_id, url, title, date, is_new, is_excluded)
|
168 |
+
VALUES (?, ?, ?, ?, 1, 0)
|
169 |
+
""", (source[0], link, title, datetime.now().strftime('%Y-%m-%d')))
|
170 |
+
self.conn.commit()
|
171 |
+
|
172 |
+
def display_tabs(self):
|
173 |
+
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([
|
174 |
+
t("new_articles"), t("rated_articles"), t("clicked_not_rated"),
|
175 |
+
t("tagged_articles"), t("ignored_articles"), t("excluded_articles")
|
176 |
+
])
|
177 |
+
|
178 |
+
all_tags = self.get_all_tags()
|
179 |
+
|
180 |
+
with tab1:
|
181 |
+
st.header(t("new_articles"))
|
182 |
+
self.display_paginated_articles(self.get_new_articles(), all_tags, "nouveaux")
|
183 |
+
|
184 |
+
with tab2:
|
185 |
+
st.header(t("rated_articles"))
|
186 |
+
self.display_paginated_articles(self.get_rated_articles(), all_tags, "notes")
|
187 |
+
|
188 |
+
with tab3:
|
189 |
+
st.header(t("clicked_not_rated"))
|
190 |
+
self.display_paginated_articles(self.get_clicked_not_rated_articles(), all_tags, "cliques")
|
191 |
+
|
192 |
+
with tab4:
|
193 |
+
st.header(t("tagged_articles"))
|
194 |
+
self.display_paginated_articles(self.get_tagged_articles(), all_tags, "tagues")
|
195 |
+
|
196 |
+
with tab5:
|
197 |
+
st.header(t("ignored_articles"))
|
198 |
+
self.display_paginated_articles(self.get_ignored_articles(), all_tags, "ignores")
|
199 |
+
|
200 |
+
with tab6:
|
201 |
+
st.header(t("excluded_articles"))
|
202 |
+
self.display_paginated_articles(self.get_excluded_articles(), all_tags, "exclus")
|
203 |
+
|
204 |
+
def display_paginated_articles(self, articles, all_tags, tab_name, items_per_page=20):
|
205 |
+
if not articles:
|
206 |
+
st.write(t("no_articles"))
|
207 |
+
return
|
208 |
+
|
209 |
+
total_pages = (len(articles) - 1) // items_per_page + 1
|
210 |
+
|
211 |
+
page_key = f"{tab_name}_page"
|
212 |
+
if page_key not in st.session_state:
|
213 |
+
st.session_state[page_key] = 1
|
214 |
+
|
215 |
+
page = st.number_input(t("page"), min_value=1, max_value=total_pages, value=st.session_state[page_key], key=f"{tab_name}_number_input")
|
216 |
+
st.session_state[page_key] = page
|
217 |
+
|
218 |
+
start_idx = (page - 1) * items_per_page
|
219 |
+
end_idx = start_idx + items_per_page
|
220 |
+
|
221 |
+
for article in articles[start_idx:end_idx]:
|
222 |
+
self.display_article(article, all_tags, tab_name)
|
223 |
+
|
224 |
+
col1, col2, col3 = st.columns(3)
|
225 |
+
with col1:
|
226 |
+
if page > 1:
|
227 |
+
if st.button(t("previous_page"), key=f"{tab_name}_prev"):
|
228 |
+
st.session_state[page_key] = page - 1
|
229 |
+
st.rerun()
|
230 |
+
with col3:
|
231 |
+
if page < total_pages:
|
232 |
+
if st.button(t("next_page"), key=f"{tab_name}_next"):
|
233 |
+
st.session_state[page_key] = page + 1
|
234 |
+
st.rerun()
|
235 |
+
with col2:
|
236 |
+
st.write(f"{t('page')} {page}/{total_pages}")
|
237 |
+
|
238 |
+
def display_article(self, article, all_tags, tab_name):
|
239 |
+
article_id = article[0]
|
240 |
+
|
241 |
+
col1, col2, col3, col4, col5 = st.columns([3, 0.5, 1, 2, 1])
|
242 |
+
|
243 |
+
with col1:
|
244 |
+
summary_key = f"{tab_name}_summary_{article_id}"
|
245 |
+
if summary_key not in st.session_state:
|
246 |
+
st.session_state[summary_key] = None
|
247 |
+
|
248 |
+
if st.button(article[3], key=f"{tab_name}_article_{article_id}"):
|
249 |
+
summary = self.get_article_summary(article[2])
|
250 |
+
st.session_state[summary_key] = summary
|
251 |
+
self.c.execute("INSERT INTO user_actions (article_id, action, timestamp) VALUES (?, ?, ?)",
|
252 |
+
(article_id, 'click', datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
253 |
+
self.c.execute("UPDATE articles SET is_new = 0 WHERE id = ?", (article_id,))
|
254 |
+
self.conn.commit()
|
255 |
+
|
256 |
+
if st.session_state[summary_key]:
|
257 |
+
st.write(st.session_state[summary_key])
|
258 |
+
|
259 |
+
with col2:
|
260 |
+
st.markdown(f"[🔗]({article[2]})")
|
261 |
+
|
262 |
+
with col3:
|
263 |
+
rating_key = f"{tab_name}_rating_{article_id}"
|
264 |
+
current_rating = self.get_article_rating(article_id)
|
265 |
+
rating = st.slider(t("rating"), 0, 5, current_rating, key=rating_key)
|
266 |
+
if rating != current_rating:
|
267 |
+
self.c.execute("INSERT INTO user_actions (article_id, action, rating, timestamp) VALUES (?, ?, ?, ?)",
|
268 |
+
(article_id, 'rate', rating, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
269 |
+
self.conn.commit()
|
270 |
+
|
271 |
+
with col4:
|
272 |
+
tags_key = f"{tab_name}_tags_{article_id}"
|
273 |
+
current_tags = self.get_article_tags(article_id)
|
274 |
+
selected_tags = st.multiselect(t("tags"), all_tags, default=current_tags, key=tags_key)
|
275 |
+
if set(selected_tags) != set(current_tags):
|
276 |
+
tags_str = ','.join(selected_tags)
|
277 |
+
self.c.execute("INSERT INTO user_actions (article_id, action, tags, timestamp) VALUES (?, ?, ?, ?)",
|
278 |
+
(article_id, 'tag', tags_str, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
279 |
+
self.conn.commit()
|
280 |
+
|
281 |
+
with col5:
|
282 |
+
exclude_key = f"{tab_name}_exclude_{article_id}"
|
283 |
+
if st.button(t("exclude"), key=exclude_key):
|
284 |
+
self.c.execute("UPDATE articles SET is_excluded = 1 WHERE id = ?", (article_id,))
|
285 |
+
self.conn.commit()
|
286 |
+
st.rerun()
|
287 |
+
|
288 |
+
def get_config_ui(self, config):
|
289 |
+
updated_config = {}
|
290 |
+
|
291 |
+
updated_config['sources'] = st.header(t("sources"))
|
292 |
+
sources = self.c.execute("SELECT * FROM sources").fetchall()
|
293 |
+
for source in sources:
|
294 |
+
col1, col2, col3 = st.columns([3, 1, 1])
|
295 |
+
with col1:
|
296 |
+
new_title = st.text_input(f"{t('update')} {source[1]}", value=source[2], key=f"source_title_{source[0]}")
|
297 |
+
with col2:
|
298 |
+
if st.button(t("update"), key=f"update_source_{source[0]}"):
|
299 |
+
self.c.execute("UPDATE sources SET title = ? WHERE id = ?", (new_title, source[0]))
|
300 |
+
self.conn.commit()
|
301 |
+
with col3:
|
302 |
+
if st.button(t("delete"), key=f"delete_source_{source[0]}"):
|
303 |
+
self.c.execute("DELETE FROM sources WHERE id = ?", (source[0],))
|
304 |
+
self.conn.commit()
|
305 |
+
|
306 |
+
new_url = st.text_input(t("add_new_source"))
|
307 |
+
if st.button(t("add_source")):
|
308 |
+
title = self.fetch_page_title(new_url)
|
309 |
+
self.c.execute("INSERT INTO sources (url, title) VALUES (?, ?)", (new_url, title))
|
310 |
+
self.conn.commit()
|
311 |
+
|
312 |
+
st.header(t("tags"))
|
313 |
+
tags = self.get_all_tags_with_descriptions()
|
314 |
+
for tag, description in tags:
|
315 |
+
col1, col2, col3, col4 = st.columns([2, 3, 1, 1])
|
316 |
+
with col1:
|
317 |
+
st.text(tag)
|
318 |
+
with col2:
|
319 |
+
new_description = st.text_input(f"{t('update')} {tag}", value=description, key=f"tag_desc_{tag}")
|
320 |
+
with col3:
|
321 |
+
if st.button(t("update"), key=f"update_tag_{tag}"):
|
322 |
+
self.add_or_update_tag(tag, new_description)
|
323 |
+
with col4:
|
324 |
+
if st.button(t("delete"), key=f"delete_tag_{tag}"):
|
325 |
+
self.delete_tag(tag)
|
326 |
+
|
327 |
+
new_tag = st.text_input(t("new_tag"))
|
328 |
+
new_tag_description = st.text_input(t("new_tag_description"))
|
329 |
+
if st.button(t("add_tag")):
|
330 |
+
self.add_or_update_tag(new_tag, new_tag_description)
|
331 |
+
|
332 |
+
# Ajout des configurations modifiées au dictionnaire updated_config
|
333 |
+
updated_config["sources"] = sources
|
334 |
+
updated_config["new_source_url"] = new_url
|
335 |
+
updated_config["tags"] = tags
|
336 |
+
updated_config["new_tag"] = new_tag
|
337 |
+
updated_config["new_tag_description"] = new_tag_description
|
338 |
+
|
339 |
+
return updated_config
|
340 |
+
|
341 |
+
def fetch_page_title(self, url):
|
342 |
+
try:
|
343 |
+
response = requests.get(url)
|
344 |
+
soup = BeautifulSoup(response.text, 'html.parser')
|
345 |
+
return soup.title.string
|
346 |
+
except:
|
347 |
+
return url
|
348 |
+
|
349 |
+
def mark_not_new(self, source_id):
|
350 |
+
self.c.execute("UPDATE articles SET is_new = 0 WHERE source_id = ?", (source_id,))
|
351 |
+
self.conn.commit()
|
352 |
+
|
353 |
+
def scan_new_links(self, source_id, url):
|
354 |
+
links = self.scan_links(url)
|
355 |
+
filtered_links = []
|
356 |
+
for link, title in links:
|
357 |
+
self.c.execute("SELECT id, is_excluded FROM articles WHERE url = ?", (link,))
|
358 |
+
result = self.c.fetchone()
|
359 |
+
if result is None:
|
360 |
+
filtered_links.append((link, title))
|
361 |
+
        return filtered_links

    def scan_links(self, url):
        links = set()
        try:
            response = requests.get(url)
            soup = BeautifulSoup(response.text, 'html.parser')
            for link in soup.find_all('a'):
                href = link.get('href')
                title = link.text.strip() or href
                if href and href.startswith('http'):
                    try:
                        # Keep only links whose target page contains an <article> tag
                        article_response = requests.get(href)
                        article_soup = BeautifulSoup(article_response.text, 'html.parser')
                        if article_soup.find('article'):
                            links.add((href, title))
                    except requests.RequestException:
                        pass
        except requests.RequestException:
            st.error(f"Error while scanning {url}")
        return list(links)

    def get_article_summary(self, url, model="qwen2"):
        prompt = f"Briefly summarize the article at this URL: {url}"
        response = ollama.generate(model=model, prompt=prompt)
        return response['response']

    def get_new_articles(self):
        # Articles flagged as new that the user has not interacted with yet
        return self.c.execute("""
            SELECT * FROM articles
            WHERE is_new = 1
            AND is_excluded = 0
            AND id NOT IN (
                SELECT DISTINCT article_id
                FROM user_actions
                WHERE action IN ('click', 'rate', 'tag')
            )
            ORDER BY date DESC
        """).fetchall()

    def get_rated_articles(self):
        return self.c.execute("""
            SELECT DISTINCT a.*
            FROM articles a
            JOIN user_actions ua ON a.id = ua.article_id
            WHERE ua.action = 'rate'
            AND a.is_excluded = 0
            ORDER BY ua.timestamp DESC
        """).fetchall()

    def get_clicked_not_rated_articles(self):
        return self.c.execute("""
            SELECT DISTINCT a.*
            FROM articles a
            JOIN user_actions ua ON a.id = ua.article_id
            WHERE ua.action = 'click'
            AND a.is_excluded = 0
            AND a.id NOT IN (
                SELECT article_id
                FROM user_actions
                WHERE action IN ('rate', 'tag')
            )
            ORDER BY ua.timestamp DESC
        """).fetchall()

    def get_tagged_articles(self):
        return self.c.execute("""
            SELECT DISTINCT a.*
            FROM articles a
            JOIN user_actions ua ON a.id = ua.article_id
            WHERE ua.action = 'tag'
            AND a.is_excluded = 0
            AND a.id NOT IN (
                SELECT article_id
                FROM user_actions
                WHERE action IN ('rate', 'click')
            )
            ORDER BY ua.timestamp DESC
        """).fetchall()

    def get_ignored_articles(self):
        return self.c.execute("""
            SELECT * FROM articles
            WHERE is_new = 0
            AND is_excluded = 0
            AND id NOT IN (
                SELECT DISTINCT article_id
                FROM user_actions
                WHERE action IN ('click', 'rate', 'tag')
            )
            ORDER BY date DESC
        """).fetchall()

    def get_excluded_articles(self):
        return self.c.execute("""
            SELECT * FROM articles
            WHERE is_excluded = 1
            ORDER BY date DESC
        """).fetchall()

    def get_article_rating(self, article_id):
        self.c.execute("SELECT rating FROM user_actions WHERE article_id = ? AND action = 'rate' ORDER BY timestamp DESC LIMIT 1", (article_id,))
        result = self.c.fetchone()
        return result[0] if result else 0

    def get_article_tags(self, article_id):
        self.c.execute("SELECT tags FROM user_actions WHERE article_id = ? AND action = 'tag' ORDER BY timestamp DESC LIMIT 1", (article_id,))
        result = self.c.fetchone()
        return result[0].split(',') if result and result[0] else []

    def get_all_tags_with_descriptions(self):
        return self.c.execute("SELECT name, description FROM tags").fetchall()

    def add_or_update_tag(self, name, description):
        self.c.execute("INSERT OR REPLACE INTO tags (name, description) VALUES (?, ?)", (name, description))
        self.conn.commit()

    def delete_tag(self, name):
        self.c.execute("DELETE FROM tags WHERE name = ?", (name,))
        self.conn.commit()

    def get_reference_data(self):
        # Fetch articles together with their most recent rating
        # (SQLite idiom: the bare MAX(timestamp) in HAVING keeps the latest row per article)
        self.c.execute("""
            SELECT a.id, a.url, a.title, COALESCE(ua.rating, 0) as rating
            FROM articles a
            LEFT JOIN (
                SELECT article_id, rating
                FROM user_actions
                WHERE action = 'rate'
                GROUP BY article_id
                HAVING MAX(timestamp)
            ) ua ON a.id = ua.article_id
            WHERE a.is_excluded = 0
            ORDER BY rating DESC, a.date DESC
        """)
        articles = self.c.fetchall()

        # Split articles into valid (rated) and rejected (unrated)
        reference_data_valid = [(article[1], article[2], article[3]) for article in articles if article[3] > 0]
        reference_data_rejected = [(article[1], article[2]) for article in articles if article[3] == 0]

        return reference_data_valid, reference_data_rejected
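
For orientation, a minimal usage sketch of how this reference data is consumed downstream (illustrative only; `pm` stands for an already-constructed plugin manager instance):

from plugins.scansite import ScansitePlugin

scansite = ScansitePlugin('scansite', pm)  # pm: hypothetical PluginManager instance
valid, rejected = scansite.get_reference_data()
# valid:    [(url, title, rating), ...] with rating > 0
# rejected: [(url, title), ...] with no rating recorded
# Callers merge both lists, giving rejected items an implicit score of 0:
reference_data = valid + [(url, title, 0) for url, title in rejected]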
plugins/webrankings.py
ADDED
@@ -0,0 +1,285 @@
import os
import streamlit as st
import requests
from bs4 import BeautifulSoup
import pandas as pd
import torch
from transformers import pipeline, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer, util
import concurrent.futures
import time
import sys
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from scipy import stats
from PyDictionary import PyDictionary
import matplotlib.pyplot as plt
import litellm
import re
import sentencepiece
import random
from global_vars import t, translations
from app import Plugin

from embeddings_ft import finetune as finetune_embeddings
from bart_ft import finetune as finetune_bart

from webrankings_helper import *
from plugins.scansite import ScansitePlugin

#from data import reference_data_valid, reference_data_rejected
#reference_data = reference_data_valid + reference_data_rejected

# Translations specific to this plugin
translations["en"].update({
    "webrankings_title": "Comparative sorter",
    "clear_memory": "Clear Memory",
    "enter_topic": "Enter the topic you're interested in (e.g. longevity):",
    "use_keyword_expansion": "Use keyword expansion",
    "test_content": "Also test link content in addition to titles",
    "select_llm_models": "Select LLM models to use",
    "select_zero_shot_models": "Select zero-shot models to use",
    "select_embedding_models": "Select embedding models to use",
    "analyze_button": "Analyze",
    "loading_models": "Loading models and analyzing links...",
    "expanded_keywords": "Expanded keywords:",
    "analysis_completed": "Analysis completed in {:.2f} seconds",
    "evaluation_results": "Evaluation results with optimal thresholds:",
    "summary_table": "Summary table of scores",
    "optimal_thresholds": "Optimal thresholds:",
    "spearman_comparison": "Comparison of Spearman correlations",
    "methods": "Methods",
    "spearman_correlation": "Spearman correlation coefficient",
    "results_for": "Results for {}",
    "device_info": "Device used for inference: {}",
    "finetune_bart_title": "BART Fine-tuning Interface",
    "finetune_embeddings_title": "Embeddings Fine-tuning Interface",
})

translations["fr"].update({
    "webrankings_title": "Analyseur de classeurs",
    "clear_memory": "Vider la mémoire",
    "enter_topic": "Entrez le sujet qui vous intéresse (ex: longévité):",
    "use_keyword_expansion": "Utiliser l'expansion des mots-clés",
    "test_content": "Tester aussi le contenu des liens en plus des titres",
    "select_llm_models": "Sélectionnez les modèles LLM à utiliser",
    "select_zero_shot_models": "Sélectionnez les modèles zero-shot à utiliser",
    "select_embedding_models": "Sélectionnez les modèles d'embedding à utiliser",
    "analyze_button": "Analyser",
    "loading_models": "Chargement des modèles et analyse des liens...",
    "expanded_keywords": "Mots-clés étendus :",
    "analysis_completed": "Analyse terminée en {:.2f} secondes",
    "evaluation_results": "Résultats d'évaluation avec les seuils optimaux :",
    "summary_table": "Tableau récapitulatif des scores",
    "optimal_thresholds": "Seuils optimaux :",
    "spearman_comparison": "Comparaison des corrélations de Spearman",
    "methods": "Méthodes",
    "spearman_correlation": "Coefficient de corrélation de Spearman",
    "results_for": "Résultats pour {}",
    "device_info": "Dispositif utilisé pour l'inférence : {}",
    "finetune_bart_title": "Interface de Fine-tuning BART",
    "finetune_embeddings_title": "Interface de Fine-tuning des Embeddings",
})


# List of LLM models
llm_models = []  # e.g. ["ollama/llama3", "ollama/llama3.1", "ollama/qwen2", "ollama/phi3:medium-128k", "ollama/openhermes"]

# List of zero-shot models as (display name, model path) pairs
zero_shot_models = [
    ("facebook/bart-large-mnli", "facebook/bart-large-mnli"),
    ("bart-large-ft", "./bart-large-ft")
    #("cross-encoder/nli-deberta-v3-base", "cross-encoder/nli-deberta-v3-base")
]

# List of embedding models as (display name, model path) pairs
embedding_models = [
    ("paraphrase-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2"),
    ("all-MiniLM-L6-v2", "all-MiniLM-L6-v2"),
    ("nomic-embed-text-v1", "nomic-ai/nomic-embed-text-v1"),
    ("embeddings-ft", "./embeddings-ft")
]

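# Illustrative only: each list holds (display name, model path) pairs, so trying
# another checkpoint is a one-line append. The ids below are real Hugging Face
# models, but these lines are examples, not part of the shipped configuration.
#zero_shot_models.append(("nli-deberta-v3-base", "cross-encoder/nli-deberta-v3-base"))
#embedding_models.append(("all-mpnet-base-v2", "all-mpnet-base-v2"))
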
class WebrankingsPlugin(Plugin):
    def __init__(self, name, plugin_manager):
        super().__init__(name, plugin_manager)
        self.scansite_plugin = ScansitePlugin('scansite', plugin_manager)

    def get_tabs(self):
        return [
            {"name": t("webrankings_title"), "plugin": "webrankings"}
        ]

    def run(self, config):
        tab1, tab2, tab3 = st.tabs([t("webrankings_title"), t("finetune_bart_title"), t("finetune_embeddings_title")])
        reference_data_valid, reference_data_rejected = self.scansite_plugin.get_reference_data()
        # Tuples are (url, title, rating); rejected items carry an implicit rating of 0
        reference_data = reference_data_valid + [(url, title, 0) for url, title in reference_data_rejected]

        with tab1:
            st.title(t("webrankings_title"))
            if st.button(t("clear_memory")):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                clear_globals()
                reset_cuda_context()

            topic = st.text_input(t("enter_topic"), value="longevity, health, healthspan, lifespan")
            use_synonyms = st.checkbox(t("use_keyword_expansion"), value=False)
            check_content = st.checkbox(t("test_content"), value=False)

            selected_llm_models = st.multiselect(t("select_llm_models"), llm_models, default=llm_models)
            selected_zero_shot_models = st.multiselect(t("select_zero_shot_models"), [m[0] for m in zero_shot_models], default=[m[0] for m in zero_shot_models])
            selected_embedding_models = st.multiselect(t("select_embedding_models"), [m[0] for m in embedding_models], default=[m[0] for m in embedding_models])

            if st.button(t("analyze_button")):
                with st.spinner(t("loading_models")):
                    device = "cuda" if torch.cuda.is_available() else "cpu"

                    # Prepare the models (only those selected in the UI)
                    zero_shot_classifiers = {name: pipeline("zero-shot-classification", model=model, device=device)
                                             for name, model in zero_shot_models if name in selected_zero_shot_models}
                    embedding_models_dict = {}
                    for name, model in embedding_models:
                        if name not in selected_embedding_models:
                            continue
                        # The fine-tuned model is only loaded if it has been trained and saved locally
                        if name == "embeddings-ft":
                            if os.path.exists('./embeddings-ft'):
                                embedding_models_dict[name] = SentenceTransformer('./embeddings-ft', trust_remote_code=True).to(device)
                        else:
                            embedding_models_dict[name] = SentenceTransformer(model, trust_remote_code=True).to(device)
                    bert_models = [AutoModel.from_pretrained('bert-base-uncased').to(device)]
                    tfidf_objects = [TfidfVectorizer()]
                    #release_vram(zero_shot_classifiers, embedding_models_dict, bert_models, tfidf_objects)

                    # Keyword expansion (uses the first selected LLM model)
                    if use_synonyms and selected_llm_models:
                        expanded_query = []
                        for word in topic.split():
                            expanded_query.extend(expand_keywords_llm(word, llm_model=selected_llm_models[0]))
                        expanded_query = " ".join(expanded_query)
                        st.write(t("expanded_keywords"), expanded_query)
                    else:
                        expanded_query = topic

                    start_time = time.time()
                    # Analyze each link; titles are used as row keys ("Titre") below
                    results = []
                    for url, title, note in reference_data:
                        result = analyze_link(
                            title, url, topic, zero_shot_classifiers, embedding_models_dict,
                            expanded_query, selected_llm_models, check_content
                        )
                        if result is not None:
                            results.append(result)
                    end_time = time.time()

                    # Free VRAM and other resources
                    release_vram(zero_shot_classifiers, embedding_models_dict, bert_models, tfidf_objects)

                    # Build a DataFrame with all the results
                    df = pd.DataFrame(results)
                    print(f"Analysis completed in {end_time - start_time:.2f} seconds")

                    st.success(t("analysis_completed").format(end_time - start_time))

                    # Evaluate and display the results
                    evaluation_results = {}
                    optimal_thresholds = {}
                    for column in df.columns:
                        if column != "Titre":
                            method_scores = df.set_index("Titre")[column].to_dict()

                            # Reference tuples are (url, title, ...), so item[1] matches the "Titre" key
                            optimal_threshold = find_optimal_threshold(
                                [item[1] for item in reference_data_valid],
                                [item[1] for item in reference_data_rejected],
                                method_scores
                            )
                            optimal_thresholds[column] = optimal_threshold

                            evaluation_results[column] = evaluate_ranking(
                                [item[1] for item in reference_data_valid],
                                [item[1] for item in reference_data_rejected],
                                method_scores,
                                optimal_threshold, False
                            )

                    # Display the results
                    st.write(t("evaluation_results"))
                    eval_df = pd.DataFrame(evaluation_results).T
                    st.dataframe(eval_df)

                    st.subheader(t("summary_table"))
                    st.dataframe(df)

                    st.write(t("optimal_thresholds"))
                    st.json(optimal_thresholds)

                    # Bar chart comparing Spearman correlations across methods
                    spearman_scores = [results['spearman_correlation'] for results in evaluation_results.values()]
                    plt.figure(figsize=(15, 8))
                    plt.bar(evaluation_results.keys(), spearman_scores)
                    plt.title(t("spearman_comparison"))
                    plt.xlabel(t("methods"))
                    plt.ylabel(t("spearman_correlation"))
                    plt.xticks(rotation=90, ha='right')
                    plt.tight_layout()
                    st.pyplot(plt)

                    # Per-method results, filtered by each method's optimal threshold
                    for column in df.columns:
                        if column != "Titre":
                            st.subheader(t("results_for").format(column))
                            df_method = df[["Titre", column]].sort_values(column, ascending=False)
                            threshold = find_optimal_threshold(
                                [item[1] for item in reference_data_valid],
                                [item[1] for item in reference_data_rejected],
                                df_method.set_index("Titre")[column].to_dict()
                            )
                            df_method = df_method[df_method[column] > threshold]
                            st.dataframe(df_method)

        with tab2:
            st.title(t("finetune_bart_title"))
            num_epochs = st.number_input("Number of epochs", min_value=1, max_value=10, value=2)
            lr = st.number_input("Learning rate", min_value=1e-6, max_value=1e-1, value=2e-5, format="%.6f", step=1e-5)
            weight_decay = st.number_input("Weight decay", min_value=0.0, max_value=0.1, value=0.01, step=0.005)
            batch_size = st.number_input("Batch size", min_value=1, max_value=16, value=1)
            start = st.slider("Initial score of valid data", min_value=0.0, max_value=1.0, value=0.9, step=0.01)
            model_name = st.text_input("Model name", value='facebook/bart-large-mnli')
            num_warmup_steps = st.number_input("Number of warmup steps", min_value=0, max_value=100, value=0)

            # Launch the fine-tuning (note: `start` is collected but not forwarded to finetune_bart)
            if st.button("Start fine-tuning"):
                with st.spinner("Fine-tuning in progress..."):
                    finetune_bart(num_epochs=num_epochs, lr=lr, weight_decay=weight_decay,
                                  batch_size=batch_size, model_name=model_name, output_model='./bart-large-ft',
                                  num_warmup_steps=num_warmup_steps)
                st.success("Fine-tuning finished and model saved.")

        with tab3:
            st.title(t("finetune_embeddings_title"))
            num_epochs_emb = st.number_input("Number of epochs (embeddings)", min_value=1, max_value=100, value=10)
            lr_emb = st.number_input("Learning rate (embeddings)", min_value=1e-6, max_value=1e-1, value=2e-5, format="%.6f", step=5e-6)
            weight_decay_emb = st.number_input("Weight decay (embeddings)", min_value=0.0, max_value=0.1, value=0.01, step=0.005)
            batch_size_emb = st.number_input("Batch size (embeddings)", min_value=1, max_value=32, value=16)
            start_emb = st.slider("Initial score of valid data (embeddings)", min_value=0.0, max_value=1.0, value=0.9, step=0.01)
            model_name_emb = st.selectbox("Base embeddings model", ["nomic-ai/nomic-embed-text-v1", "all-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2"])
            margin_emb = st.slider("Margin (embeddings)", min_value=0.0, max_value=1.0, value=0.5, step=0.01)

            # Launch the embeddings fine-tuning (note: `start_emb` and `margin_emb` are collected but not forwarded)
            if st.button("Start embeddings fine-tuning"):
                with st.spinner("Embeddings fine-tuning in progress..."):
                    finetune_embeddings(model_name=model_name_emb, output_model_name="./embeddings-ft",
                                        num_epochs=num_epochs_emb,
                                        learning_rate=lr_emb,
                                        weight_decay=weight_decay_emb,
                                        batch_size=batch_size_emb,
                                        )
                st.success("Embeddings fine-tuning finished and model saved.")

        # Show which device is used for inference
        device = "GPU (CUDA)" if torch.cuda.is_available() else "CPU"
        st.sidebar.info(t("device_info").format(device))
webrankings_helper.py
ADDED
@@ -0,0 +1,410 @@
import gc
import random
import re
import sys
import time
import concurrent.futures

import requests
from bs4 import BeautifulSoup
import pandas as pd
import numpy as np
import torch
import litellm
import sentencepiece
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer, util
from PyDictionary import PyDictionary

def score_with_llm(title, topic, llm_model):
    prompt = f"""Evaluate the relevance of the following article to the topic '{topic}'.
Article title: {title}
Give a final relevance score between 0 and 1, where 1 is very relevant and 0 is not relevant at all.
Respond only with a number between 0 and 1."""

    try:
        response = litellm.completion(
            model=llm_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=10
        )

        # Extract the first number from the model's reply and clamp it to [0, 1]
        score_match = re.search(r'\d+(\.\d+)?', response.choices[0].message.content.strip())
        if score_match:
            score = float(score_match.group())
            print(f"LLM score: {score}")
            return max(0, min(score, 1))
        else:
            print(f"Could not extract a score from LLM response: {response.choices[0].message.content}")
            return None
    except Exception as e:
        print(f"Error in scoring with LLM {llm_model}: {str(e)}")
        return None

def expand_keywords_llm(keyword, max_synonyms=3, llm_model="ollama/qwen2"):
    prompt = f"""Please provide up to {max_synonyms} synonyms or closely related terms for the word or phrase: "{keyword}".
Return only the list of synonyms, separated by commas, without any additional explanation."""

    try:
        response = litellm.completion(
            model=llm_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=50
        )

        synonyms = [s.strip() for s in response.choices[0].message.content.split(',')]
        return [keyword] + synonyms[:max_synonyms]
    except Exception as e:
        print(f"Error in expanding keywords with LLM {llm_model}: {str(e)}")
        return [keyword]

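def _demo_llm_helpers():
    # Illustrative sketch (not called anywhere): assumes a local Ollama server
    # is running and the qwen2 model has been pulled; output is model-dependent.
    keywords = expand_keywords_llm("longevity", max_synonyms=3, llm_model="ollama/qwen2")
    print(keywords)  # e.g. ['longevity', 'lifespan', 'long life', ...]
    score = score_with_llm("Study on caloric restriction and lifespan", "longevity", "ollama/qwen2")
    print(score)     # a float in [0, 1], or None if no number could be parsed
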
# Fetch the links from a site's homepage
def get_homepage_links(url):
    response = requests.get(url)
    soup = BeautifulSoup(response.text, 'html.parser')
    links = soup.find_all('a', href=True)
    return [(link.text.strip(), link['href']) for link in links if link.text.strip()]

# Fetch the text content of an article page
def get_article_content(url):
    try:
        print(f"Fetching content from: {url}")
        response = requests.get(url)
        print(f"HTTP response size: {len(response.content)} bytes")
        soup = BeautifulSoup(response.text, 'html.parser')
        print(f"Size of the soup object: {sys.getsizeof(soup)} bytes")
        article = soup.find('article')
        if article:
            paragraphs = article.find_all('p')
            content = ' '.join([p.text for p in paragraphs])
            print(f"Paragraphs retrieved: {len(content)} characters")
            return content
        print("No article content found")
        return ""
    except Exception as e:
        print(f"Error while fetching content: {str(e)}")
        return ""

# Zero-shot relevance analysis
def zero_shot_analysis(text, topic, classifier):
    if not text:
        print("Empty text for zero-shot analysis")
        return 0.0
    result = classifier(text, candidate_labels=[topic, f"not {topic}"], multi_label=False)
    print(f"Zero-shot score: {result['scores'][0]}")
    return result['scores'][0]

# Embedding-based relevance analysis
def embedding_analysis(text, topic_embedding, model):
    if not text:
        print("Empty text for embedding analysis")
        return 0.0
    text_embedding = model.encode([text], convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(text_embedding, topic_embedding).item()
    print(f"Embedding score: {similarity}")
    return similarity

def preprocess_text(text):
    # Requires the NLTK WordNet corpus:
    #   pip install nltk && python -m nltk.downloader wordnet
    from nltk.corpus import wordnet

    # Tokenize the text
    tokens = text.lower().split()
    # Expand each token with its synonyms
    expanded_tokens = []
    for token in tokens:
        synonyms = set()
        for syn in wordnet.synsets(token):
            for lemma in syn.lemmas():
                synonyms.add(lemma.name())
        expanded_tokens.extend(list(synonyms))
    return ' '.join(expanded_tokens)

def improved_tfidf_similarity(texts, query):
    # Preprocess texts and query
    preprocessed_texts = [preprocess_text(text) for text in texts]
    preprocessed_query = preprocess_text(query)

    # Combine texts and query for vectorization
    all_texts = preprocessed_texts + [preprocessed_query]

    # Use TfidfVectorizer with custom parameters
    vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=1, smooth_idf=True)
    tfidf_matrix = vectorizer.fit_transform(all_texts)

    # Calculate cosine similarity
    cosine_similarities = cosine_similarity(tfidf_matrix[-1], tfidf_matrix[:-1]).flatten()

    # Normalize similarities to avoid zero scores (epsilon guards against a zero range)
    normalized_similarities = (cosine_similarities - cosine_similarities.min()) / (cosine_similarities.max() - cosine_similarities.min() + 1e-8)

    return normalized_similarities

def improved_tfidf_similarity_v2(texts, query):
    # Combine texts and query, treating each comma-separated word or phrase as a separate document
    all_docs = [word.strip() for text in texts for word in text.split(',')] + [word.strip() for word in query.split(',')]

    # Create TF-IDF matrix
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(all_docs)

    # Calculate document vectors by summing the TF-IDF vectors of their words
    doc_vectors = []
    query_vector = np.zeros((1, tfidf_matrix.shape[1]))

    current_doc = 0
    for i, doc in enumerate(all_docs):
        if i < len(all_docs) - len(query.split(',')):  # If it's part of the texts
            if current_doc == len(texts):
                break
            if doc in texts[current_doc]:
                doc_vectors.append(tfidf_matrix[i].toarray())
            else:
                current_doc += 1
                doc_vectors.append(tfidf_matrix[i].toarray())
        else:  # If it's part of the query
            query_vector += tfidf_matrix[i].toarray()

    doc_vectors = np.array([np.sum(doc, axis=0) for doc in doc_vectors])

    # Calculate cosine similarity
    similarities = cosine_similarity(query_vector, doc_vectors).flatten()

    # Normalize similarities to avoid zero scores
    normalized_similarities = (similarities - similarities.min()) / (similarities.max() - similarities.min() + 1e-8)

    return normalized_similarities

# Example usage:
# texts = ["longevity, health, aging", "computer science, AI"]
# query = "longevity, life extension, anti-aging"
# results = improved_tfidf_similarity_v2(texts, query)
# print(results)

# Baseline TF-IDF similarity
def tfidf_similarity(texts, query):
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(texts + [query])
    cosine_similarities = cosine_similarity(tfidf_matrix[-1], tfidf_matrix[:-1]).flatten()
    return cosine_similarities

# Mean-pooled BERT embedding similarity
def bert_similarity(texts, query, model_name='bert-base-uncased'):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    def get_embedding(text):
        inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

    query_embedding = get_embedding(query)
    text_embeddings = [get_embedding(text) for text in texts]

    similarities = [cosine_similarity([query_embedding], [text_embedding])[0][0] for text_embedding in text_embeddings]
    return similarities

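def _demo_similarity_baselines():
    # Illustrative sketch (not called anywhere): compares the two baseline
    # scorers on toy data; the first call downloads bert-base-uncased.
    texts = ["New study on caloric restriction and lifespan",
             "Quarterly earnings report for a retail chain"]
    query = "longevity, health, healthspan, lifespan"
    print(tfidf_similarity(texts, query))  # lexical overlap; near 0 without shared words
    print(bert_similarity(texts, query))   # dense similarity; typically ranks the first text higher
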
# Main analysis function
def analyze_link(title, link, topic, zero_shot_classifiers, embedding_models, expanded_query, llm_models, testcontent):
    print(f"\nAnalyzing: {title}")

    results = {
        "Titre": title,
        #"TF-IDF (titre)": improved_tfidf_similarity_v2([title], expanded_query)[0],
        #"BERT (titre)": bert_similarity([title], expanded_query)[0],
    }

    # Zero-shot analysis
    for name, classifier in zero_shot_classifiers.items():
        results[f"Zero-shot (titre) - {name}"] = zero_shot_analysis(title, topic, classifier)

    # Embedding analysis
    for name, model in embedding_models.items():
        topic_embedding = model.encode([expanded_query], convert_to_tensor=True)
        results[f"Embeddings (titre) - {name}"] = embedding_analysis(title, topic_embedding, model)

    # LLM analysis
    for model in llm_models:
        results[f"LLM Score - {model}"] = score_with_llm(title, topic, model)

    if testcontent:
        content = get_article_content(link)
        #results["TF-IDF (contenu)"] = improved_tfidf_similarity_v2([content], expanded_query)[0]
        #results["BERT (contenu)"] = bert_similarity([content], expanded_query)[0]

        # Zero-shot analysis
        for name, classifier in zero_shot_classifiers.items():
            results[f"Zero-shot (contenu) - {name}"] = zero_shot_analysis(content, topic, classifier)

        # Embedding analysis
        for name, model in embedding_models.items():
            topic_embedding = model.encode([expanded_query], convert_to_tensor=True)
            results[f"Embeddings (contenu) - {name}"] = embedding_analysis(content, topic_embedding, model)

        # LLM analysis
        for model in llm_models:
            results[f"LLM Content Score - {model}"] = score_with_llm(content, topic, model)

    return results

def evaluate_ranking(reference_data_valid, reference_data_rejected, method_scores, threshold, silent):
    simple_score = 0
    true_positives = 0
    false_positives = 0
    true_negatives = 0
    false_negatives = 0

    # Build one list of all items with their status (1 for valid, 0 for rejected)
    all_items = [(item, 1) for item in reference_data_valid] + [(item, 0) for item in reference_data_rejected]

    # Sort items by their score under the method being evaluated.
    # Reversing first avoids spurious perfect rankings when a method
    # gives every item the same score.
    all_items_temp = all_items.copy()
    #random.shuffle(all_items_temp)
    all_items_temp.reverse()
    sorted_method = sorted([(item, method_scores.get(item, 0)) for item, _ in all_items_temp],
                           key=lambda x: x[1], reverse=True)

    # Build rank lists for the Spearman correlation
    reference_ranks = []
    method_ranks = []

    for i, (item, status) in enumerate(all_items):
        method_score = method_scores.get(item, 0)
        method_rank = next(j for j, (it, score) in enumerate(sorted_method) if it == item)

        reference_ranks.append(i)
        method_ranks.append(method_rank)

        if status == 1:  # Valid item
            if method_score >= threshold:
                simple_score += 1
                true_positives += 1
            else:
                simple_score -= 1
                false_negatives += 1
        else:  # Rejected item
            if method_score < threshold:
                simple_score += 1
                true_negatives += 1
            else:
                simple_score -= 1
                false_positives += 1

    # Compute the Spearman rank correlation
    if not silent:
        print("+++")
        print(reference_ranks)
        print("---")
        print(method_ranks)
    spearman_corr, _ = stats.spearmanr(reference_ranks, method_ranks)

    # Compute precision, recall and F1-score
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return {
        "simple_score": simple_score,
        "spearman_correlation": spearman_corr,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
    }

def find_optimal_threshold(reference_data_valid, reference_data_rejected, method_scores):
    # Grid-search the threshold in [0, 1] that maximizes the simple score
    best_score = float('-inf')
    best_threshold = 0

    for threshold in np.arange(0, 1.05, 0.05):
        result = evaluate_ranking(
            reference_data_valid,
            reference_data_rejected,
            method_scores,
            threshold, True
        )
        if result['simple_score'] > best_score:
            best_score = result['simple_score']
            best_threshold = threshold

    return best_threshold

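def _demo_threshold_search():
    # Illustrative sketch (not called anywhere): toy scores showing the expected
    # shapes. Here the two groups separate cleanly, so the metrics are perfect.
    valid = ["title A", "title B"]      # items the user rated
    rejected = ["title C", "title D"]   # items the user ignored
    scores = {"title A": 0.9, "title B": 0.7, "title C": 0.3, "title D": 0.1}
    threshold = find_optimal_threshold(valid, rejected, scores)
    metrics = evaluate_ranking(valid, rejected, scores, threshold, True)
    print(threshold)                    # first grid value above 0.3 (step 0.05)
    print(metrics["f1_score"], metrics["spearman_correlation"])  # 1.0 1.0
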
def reset_cuda_context():
    # Everything here is skipped on CPU-only machines
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.set_device(torch.cuda.current_device())
        torch.cuda.synchronize()

def clear_models():
    # NOTE: expects these collections to exist as module-level globals
    global zero_shot_classifiers, embedding_models_dict, bert_models, tfidf_objects

    for classifier in zero_shot_classifiers.values():
        del classifier
    zero_shot_classifiers.clear()

    for model in embedding_models_dict.values():
        del model
    embedding_models_dict.clear()

    for model in bert_models:
        del model
    bert_models.clear()

    for vectorizer in tfidf_objects:
        del vectorizer
    tfidf_objects.clear()

    torch.cuda.empty_cache()
    gc.collect()

def clear_globals():
    # Drop any module-level tensors or modules so their memory can be reclaimed
    for name in list(globals()):
        if isinstance(globals()[name], (torch.nn.Module, torch.Tensor)):
            del globals()[name]

def release_vram(zero_shot_classifiers, embedding_models, bert_models, tfidf_objects):
    # Drop references to the zero-shot classifiers
    for model in zero_shot_classifiers.values():
        del model

    # Drop references to the embedding models
    for model in embedding_models.values():
        del model

    # Drop references to the BERT models
    for model in bert_models:
        del model

    # Drop references to the TF-IDF objects
    for obj in tfidf_objects:
        del obj

    # Flush the GPU memory cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()
    clear_globals()
    reset_cuda_context()

def load_finetuned_model(model_path):
    # EmbeddingModel is expected to be provided by the fine-tuning module (e.g. embeddings_ft)
    checkpoint = torch.load(model_path)
    base_model = AutoModel.from_pretrained(checkpoint['base_model_name'])
    model = EmbeddingModel(base_model)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model