johannoriel commited on
Commit
f34a6fd
·
1 Parent(s): 864cca3

Initial relase. Tested. Working

Browse files
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import importlib
4
+ import streamlit as st
5
+ from typing import List, Dict, Any
6
+ from dotenv import load_dotenv
7
+ from global_vars import translations, t
8
+
9
+ # Constantes
10
+ CONFIG_FILE = "config.json"
11
+
12
+ def load_config() -> Dict[str, Any]:
13
+ if os.path.exists(CONFIG_FILE):
14
+ with open(CONFIG_FILE, 'r') as f:
15
+ return json.load(f)
16
+ return {}
17
+
18
+ def save_config(config: Dict[str, Any]):
19
+ with open(CONFIG_FILE, 'w') as f:
20
+ json.dump(config, f, indent=2)
21
+
22
+ # Fonction pour mettre à jour la langue
23
+ def set_lang(language):
24
+ st.session_state.lang = language
25
+
26
+ # Fonction de traduction
27
+ def t(key: str) -> str:
28
+ return translations[st.session_state.lang].get(key, key)
29
+
30
+ class Plugin:
31
+ def __init__(self, name, plugin_manager):
32
+ self.name = name
33
+ self.plugin_manager = plugin_manager
34
+
35
+ def get_config_fields(self) -> Dict[str, Any]:
36
+ return {}
37
+
38
+ def get_config_ui(self, config):
39
+ updated_config = {}
40
+ for field, params in self.get_config_fields().items():
41
+ if params['type'] == 'select':
42
+ updated_config[field] = st.selectbox(
43
+ params['label'],
44
+ options=[option[0] for option in params['options']],
45
+ format_func=lambda x: dict(params['options'])[x],
46
+ index=[option[0] for option in params['options']].index(config.get(field, params['default']))
47
+ )
48
+ elif params['type'] == 'textarea':
49
+ updated_config[field] = st.text_area(
50
+ params['label'],
51
+ value=config.get(field, params['default'])
52
+ )
53
+ else:
54
+ updated_config[field] = st.text_input(
55
+ params['label'],
56
+ value=config.get(field, params['default']),
57
+ type="password" if field.startswith("pass") else "default"
58
+ )
59
+ return updated_config
60
+
61
+ def get_tabs(self) -> List[Dict[str, Any]]:
62
+ return []
63
+
64
+ def run(self, config: Dict[str, Any]):
65
+ pass
66
+
67
+ def get_sidebar_config_ui(self, config: Dict[str, Any]) -> Dict[str, Any]:
68
+ return {}
69
+
70
+ class PluginManager:
71
+ def __init__(self):
72
+ self.plugins: Dict[str, Plugin] = {}
73
+ self.starred_plugins: Set[str] = set()
74
+
75
+ def load_plugins(self):
76
+ plugins_dir = 'plugins'
77
+ for filename in os.listdir(plugins_dir):
78
+ if filename.endswith('.py'):
79
+ module_name = filename[:-3]
80
+ module = importlib.import_module(f'plugins.{module_name}')
81
+ plugin_class = getattr(module, f'{module_name.capitalize()}Plugin')
82
+ self.plugins[module_name] = plugin_class(module_name, self)
83
+
84
+ def get_plugin(self, plugin_name: str) -> Plugin:
85
+ return self.plugins.get(plugin_name)
86
+
87
+ def get_all_config_ui(self, config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
88
+ all_ui = {}
89
+ for plugin_name, plugin in sorted(self.plugins.items()):
90
+ with st.expander(f"{'⭐ ' if plugin_name in self.starred_plugins else ''}{t('configurations')} {plugin_name}"):
91
+ all_ui[plugin_name] = plugin.get_config_ui(config.get(plugin_name, {}))
92
+ if st.button(f"{'Unstar' if plugin_name in self.starred_plugins else 'Star'} {plugin_name}"):
93
+ if plugin_name in self.starred_plugins:
94
+ self.starred_plugins.remove(plugin_name)
95
+ else:
96
+ self.starred_plugins.add(plugin_name)
97
+ self.save_starred_plugins(config)
98
+ st.rerun()
99
+ return all_ui
100
+
101
+ def get_all_tabs(self) -> List[Dict[str, Any]]:
102
+ all_tabs = []
103
+ for plugin_name, plugin in sorted(self.plugins.items()):
104
+ tabs = plugin.get_tabs()
105
+ for tab in tabs:
106
+ tab['id'] = plugin_name
107
+ tab['starred'] = plugin_name in self.starred_plugins
108
+ all_tabs.extend(tabs)
109
+ return all_tabs
110
+
111
+
112
+ def load_starred_plugins(self, config: Dict[str, Any]):
113
+ self.starred_plugins = set(config.get('starred_plugins', []))
114
+
115
+ def save_starred_plugins(self, config: Dict[str, Any]):
116
+ config['starred_plugins'] = list(self.starred_plugins)
117
+ save_config(config)
118
+
119
+ def run_plugin(self, plugin_name: str, config: Dict[str, Any]):
120
+ if plugin_name in self.plugins:
121
+ self.plugins[plugin_name].run(config)
122
+
123
+ def save_config(self, config):
124
+ save_config(config)
125
+
126
+ def main():
127
+ st.set_page_config(page_title="Veille", layout="wide")
128
+ # Initialisation du gestionnaire de plugins
129
+ plugin_manager = PluginManager()
130
+ plugin_manager.load_plugins()
131
+
132
+ # Chargement de la configuration
133
+ config = load_config()
134
+ plugin_manager.load_starred_plugins(config)
135
+
136
+ # Initialisation de la langue dans st.session_state
137
+ if 'lang' not in st.session_state:
138
+ st.session_state.lang = config['common']['language']
139
+ st.title(t("page_title"))
140
+
141
+ load_dotenv()
142
+ LLM_KEY = os.getenv("LLM_API_KEY")
143
+ config['llm_key'] = LLM_KEY
144
+
145
+ # Création des onglets avec des identifiants uniques
146
+ tabs = [{"id": "configurations", "name": t("configurations")}] + [{"id": tab['plugin'], "name": tab['name'], "starred" : tab['starred']} for tab in plugin_manager.get_all_tabs()]
147
+
148
+ # Gestion de la langue
149
+ if 'lang' not in st.session_state:
150
+ st.session_state.lang = "fr"
151
+
152
+ new_lang = st.sidebar.selectbox("Choose your language / Choisissez votre langue", options=["en", "fr"], index=["en", "fr"].index(st.session_state.lang), key="lang_selector")
153
+
154
+ if new_lang != st.session_state.lang:
155
+ st.session_state.lang = new_lang
156
+ st.rerun()
157
+
158
+ # Ajout des éléments de configuration de la sidebar pour chaque plugin
159
+ for plugin_name, plugin in plugin_manager.plugins.items():
160
+ sidebar_config = plugin.get_sidebar_config_ui(config.get(plugin_name, {}))
161
+ if sidebar_config:
162
+ #st.sidebar.markdown(f"**{plugin_name} Configuration**")
163
+ for key, value in sidebar_config.items():
164
+ config.setdefault(plugin_name, {})[key] = value
165
+
166
+ # Gestion de l'onglet sélectionné
167
+ if 'selected_tab_id' not in st.session_state:
168
+ st.session_state.selected_tab_id = "configurations"
169
+
170
+ # Sort tabs alphabetically, with starred tabs first
171
+ sorted_tabs = sorted(tabs, key=lambda x: (not x.get('starred', False), x['name']))
172
+ tab_names = [f"{'⭐ ' if tab.get('starred', False) else ''}{tab['name']}" for tab in sorted_tabs]
173
+
174
+ selected_tab_index = [tab["id"] for tab in sorted_tabs].index(st.session_state.selected_tab_id)
175
+ selected_tab = st.sidebar.radio(t("navigation"), tab_names, index=selected_tab_index, key="tab_selector")
176
+
177
+ new_selected_tab_id = next(tab["id"] for tab in sorted_tabs if f"{'⭐ ' if tab.get('starred', False) else ''}{tab['name']}" == selected_tab)
178
+
179
+ if new_selected_tab_id != st.session_state.selected_tab_id:
180
+ st.session_state.selected_tab_id = new_selected_tab_id
181
+ st.rerun()
182
+
183
+ if st.session_state.selected_tab_id == "configurations":
184
+ st.header(t("configurations"))
185
+ all_config_ui = plugin_manager.get_all_config_ui(config)
186
+
187
+ for plugin_name, ui_config in all_config_ui.items():
188
+ with st.expander(f"{t('configurations')} {plugin_name}"):
189
+ config[plugin_name] = ui_config
190
+
191
+ if st.button(t("save_button")):
192
+ save_config(config)
193
+ st.success(t("success_message"))
194
+
195
+ else:
196
+ # Exécution du plugin correspondant à l'onglet sélectionné
197
+ for tab in plugin_manager.get_all_tabs():
198
+ if tab['plugin'] == st.session_state.selected_tab_id:
199
+ plugin_manager.run_plugin(tab['plugin'], config)
200
+ break
201
+
202
+ if __name__ == "__main__":
203
+ main()
bart_ft.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from transformers import BartForSequenceClassification, BartTokenizer, get_linear_schedule_with_warmup
4
+ from transformers import AdamW
5
+ from tqdm import tqdm
6
+ import gc
7
+ from plugins.scansite import ScansitePlugin # Assurez-vous que l'import est correct
8
+
9
+ torch.cuda.empty_cache()
10
+
11
+ # Définition du dataset personnalisé
12
+ class PreferenceDataset(Dataset):
13
+ def __init__(self, data, tokenizer, max_length=128):
14
+ self.data = data
15
+ self.tokenizer = tokenizer
16
+ self.max_length = max_length
17
+
18
+ def __len__(self):
19
+ return len(self.data)
20
+
21
+ def __getitem__(self, idx):
22
+ _, title, score = self.data[idx]
23
+ encoding = self.tokenizer(title, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
24
+ # Convertir le score en label binaire (0 ou 1)
25
+ label = 1 if score > 0 else 0
26
+ return {key: val.squeeze(0) for key, val in encoding.items()}, torch.tensor(label, dtype=torch.long)
27
+
28
+ # Fonction principale de finetuning
29
+ def finetune(num_epochs=2, lr=2e-5, weight_decay=0.01, batch_size=1, model_name='facebook/bart-large-mnli', output_model='./bart-large-ft', num_warmup_steps=0):
30
+ print(f"Fine-tuning parameters:\n"
31
+ f"num_epochs: {num_epochs}\n"
32
+ f"learning rate (lr): {lr}\n"
33
+ f"weight_decay: {weight_decay}\n"
34
+ f"batch_size: {batch_size}\n"
35
+ f"model_name: {model_name}\n"
36
+ f"num_warmup_steps: {num_warmup_steps}")
37
+
38
+ # Récupérer les données de référence
39
+ scansite_plugin = ScansitePlugin("scansite", None) # Vous devrez peut-être ajuster ceci en fonction de votre structure
40
+ reference_data_valid, reference_data_rejected = scansite_plugin.get_reference_data()
41
+
42
+ # Combiner les données valides et rejetées
43
+ all_data = reference_data_valid + [(url, title, 0) for url, title in reference_data_rejected]
44
+
45
+ tokenizer = BartTokenizer.from_pretrained(model_name)
46
+ model = BartForSequenceClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)
47
+
48
+ dataset = PreferenceDataset(all_data, tokenizer)
49
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
50
+
51
+ optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
52
+
53
+ total_steps = len(dataloader) * num_epochs
54
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)
55
+
56
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57
+ model.to(device)
58
+ model.gradient_checkpointing_enable()
59
+
60
+ for epoch in range(num_epochs):
61
+ model.train()
62
+ for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
63
+ optimizer.zero_grad()
64
+
65
+ inputs, labels = batch
66
+ inputs = {k: v.to(device) for k, v in inputs.items()}
67
+ labels = labels.to(device)
68
+
69
+ outputs = model(**inputs, labels=labels)
70
+ loss = outputs.loss
71
+
72
+ loss.backward()
73
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
74
+ optimizer.step()
75
+ scheduler.step()
76
+
77
+ del inputs, outputs, labels
78
+ torch.cuda.empty_cache()
79
+ gc.collect()
80
+ #print(f"Finetuning en cours round {epoch}/{num_epochs}")
81
+ print("Finetuning terminé sauvegarde en cours.")
82
+ model.save_pretrained(output_model)
83
+ tokenizer.save_pretrained(output_model)
84
+
85
+ print("Finetuning terminé et modèle sauvegardé.")
86
+
87
+ # Appel par défaut si le script est exécuté directement
88
+ if __name__ == "__main__":
89
+ finetune()
embeddings_ft.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from sentence_transformers import SentenceTransformer, losses
4
+ from tqdm import tqdm
5
+ import gc
6
+ from plugins.scansite import ScansitePlugin
7
+
8
+ torch.cuda.empty_cache()
9
+
10
+ class PreferenceDataset(Dataset):
11
+ def __init__(self, data, tokenizer, max_length=128):
12
+ self.data = data
13
+ self.tokenizer = tokenizer
14
+ self.max_length = max_length
15
+
16
+ def __len__(self):
17
+ return len(self.data)
18
+
19
+ def __getitem__(self, idx):
20
+ url, title, score = self.data[idx]
21
+ encoded = self.tokenizer(title, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt")
22
+ return {key: val.squeeze(0) for key, val in encoded.items()}, torch.tensor(score, dtype=torch.float)
23
+
24
+ def collate_fn(batch):
25
+ input_ids = torch.stack([item[0]['input_ids'] for item in batch])
26
+ attention_mask = torch.stack([item[0]['attention_mask'] for item in batch])
27
+ scores = torch.stack([item[1] for item in batch])
28
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}, scores
29
+
30
+ def finetune(model_name='nomic-ai/nomic-embed-text-v1', output_model_name="embeddings-ft", num_epochs=2, learning_rate=2e-5, weight_decay=0.01, batch_size=8, num_warmup_steps=0):
31
+ print(f"Fine-tuning parameters:\n"
32
+ f"num_epochs: {num_epochs}\n"
33
+ f"learning rate (lr): {learning_rate}\n"
34
+ f"weight_decay: {weight_decay}\n"
35
+ f"batch_size: {batch_size}\n"
36
+ f"model_name: {model_name}\n"
37
+ f"num_warmup_steps: {num_warmup_steps}")
38
+
39
+ scansite_plugin = ScansitePlugin("scansite", None)
40
+ reference_data_valid, reference_data_rejected = scansite_plugin.get_reference_data()
41
+
42
+ valid_data_with_scores = [(url, title, (score - 1) / 8 + 0.5) for url, title, score in reference_data_valid]
43
+ rejected_data_with_scores = [(url, title, 0.0) for url, title in reference_data_rejected]
44
+
45
+ all_data = valid_data_with_scores + rejected_data_with_scores
46
+
47
+ model = SentenceTransformer(model_name, trust_remote_code=True)
48
+ tokenizer = model.tokenizer
49
+
50
+ dataset = PreferenceDataset(all_data, tokenizer)
51
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
52
+
53
+ loss_function = torch.nn.MSELoss()
54
+
55
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
56
+
57
+ total_steps = len(dataloader) * num_epochs
58
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)
59
+
60
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61
+ model.to(device)
62
+
63
+ for epoch in range(num_epochs):
64
+ model.train()
65
+ for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
66
+ input_data, scores = batch
67
+ input_data = {k: v.to(device) for k, v in input_data.items()}
68
+ scores = scores.to(device)
69
+
70
+ optimizer.zero_grad()
71
+
72
+ embeddings = model(input_data)['sentence_embedding']
73
+
74
+ # Calcul de la similarité cosinus
75
+ embeddings_norm = torch.nn.functional.normalize(embeddings, p=2, dim=1)
76
+ cosine_similarities = torch.sum(embeddings_norm, dim=1)
77
+
78
+ # Calcul de la perte
79
+ loss = loss_function(cosine_similarities, scores)
80
+
81
+ loss.backward()
82
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
83
+ optimizer.step()
84
+ scheduler.step()
85
+
86
+ del embeddings, cosine_similarities
87
+ torch.cuda.empty_cache()
88
+ gc.collect()
89
+
90
+ model.save(output_model_name)
91
+
92
+ print("Finetuning terminé et modèle sauvegardé.")
93
+
94
+ if __name__ == "__main__":
95
+ finetune()
plugins/__pycache__/common.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
plugins/__pycache__/ragllm.cpython-310.pyc ADDED
Binary file (11.7 kB). View file
 
plugins/__pycache__/scansite.cpython-310.pyc ADDED
Binary file (18.5 kB). View file
 
plugins/__pycache__/webrankings.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
plugins/common.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from global_vars import t, translations
2
+ from app import Plugin
3
+ import streamlit as st
4
+ import torch
5
+
6
+ # Ajout des traductions spécifiques à ce plugin
7
+ translations["en"].update({
8
+ "work_directory": "Work Directory",
9
+ "page_title": "News watcher",
10
+ })
11
+ translations["fr"].update({
12
+ "work_directory": "Répertoire de travail",
13
+ "page_title": "Outil de veille",
14
+ })
15
+
16
+ class CommonPlugin(Plugin):
17
+ def get_config_fields(self):
18
+ return {
19
+ "work_directory": {
20
+ "type": "text",
21
+ "label": t("work_directory"),
22
+ "default": "/home/joriel/Vidéos"
23
+ },
24
+ "language": {
25
+ "type": "select",
26
+ "label": t("preferred_language"),
27
+ "options": [("fr", "Français"), ("en", "Anglais")],
28
+ "default": "fr"
29
+ },
30
+ }
31
+
32
+ def get_tabs(self):
33
+ return [{"name": "Commun", "plugin": "common"}]
34
+
35
+ def run(self, config):
36
+ st.header("Common Plugin")
37
+ st.write(f"{t('work_directory')}: {config['common']['work_directory']}")
38
+ torch.cuda.empty_cache()
39
+ st.write("CUDA memory reset")
40
+
41
+ def remove_quotes(s):
42
+ if s.startswith('"') and s.endswith('"'):
43
+ return s[1:-1]
44
+ return s
plugins/ragllm.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from global_vars import translations, t
2
+ from app import Plugin
3
+ import streamlit as st
4
+ import yaml
5
+ from litellm import completion, embedding
6
+ import numpy as np
7
+ from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, manhattan_distances
8
+ import os
9
+ from typing import List, Dict, Any
10
+ import requests
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModel
13
+
14
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ MAX_LENGTH = 512
16
+ CHUNK_SIZE = 200 # Nombre de mots par chunk
17
+
18
+ def mean_pooling(model_output, attention_mask):
19
+ token_embeddings = model_output[0]
20
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
21
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
22
+
23
+ # Ajout des traductions spécifiques à ce plugin
24
+ translations["en"].update({
25
+ "rag_plugin_loaded": "RAG LLM Plugin loaded",
26
+ "rag_enter_text": "Enter RAG text:",
27
+ "rag_enter_question": "Enter your question:",
28
+ "rag_button_get_answer": "Get an answer",
29
+ "rag_success_text_processed": "RAG text processed successfully!",
30
+ "rag_warning_enter_text": "Please enter RAG text.",
31
+ "rag_warning_process_text_first": "Please process the RAG text first.",
32
+ "rag_warning_enter_question": "Please enter a question.",
33
+ "rag_answer": "Answer:",
34
+ "rag_citations": "Citations:",
35
+ "rag_model_provider": "Model Provider",
36
+ "rag_llm_model": "LLM Model",
37
+ "rag_embedder_model": "Embedding Model",
38
+ "rag_similarity_method": "Similarity Method",
39
+ "rag_llm_sys_prompt": "System prompt for LLM",
40
+ "rag_chunk_size": "Chunk size",
41
+ "rag_top_k_chunks": "Number of chunks to use",
42
+ "rag_default_sys_prompt": "You are an AI assistant. Your task is to analyze the provided context and answer questions based ONLY on this context. If the information is not in the context, clearly state that.",
43
+ "rag_error_fetching_models_ollama": "Error fetching Ollama models: ",
44
+ "rag_error_calling_llm": "Error calling LLM: ",
45
+ "rag_processing" : "Processing...",
46
+ })
47
+
48
+ translations["fr"].update({
49
+ "rag_plugin_loaded": "Plugin RAG LLM chargé",
50
+ "rag_enter_text": "Entrez le texte RAG :",
51
+ "rag_enter_question": "Entrez votre question :",
52
+ "rag_button_get_answer": "Obtenir une réponse",
53
+ "rag_success_text_processed": "Texte RAG traité avec succès!",
54
+ "rag_warning_enter_text": "Veuillez entrer du texte RAG.",
55
+ "rag_warning_process_text_first": "Veuillez d'abord traiter le texte RAG.",
56
+ "rag_warning_enter_question": "Veuillez entrer une question.",
57
+ "rag_answer": "Réponse :",
58
+ "rag_citations": "Citations :",
59
+ "rag_model_provider": "Fournisseur de modèle",
60
+ "rag_llm_model": "Modèle LLM",
61
+ "rag_embedder_model": "Modèle d'embedding",
62
+ "rag_similarity_method": "Méthode de similarité",
63
+ "rag_llm_sys_prompt": "Prompt système pour le LLM",
64
+ "rag_chunk_size": "Taille des chunks",
65
+ "rag_top_k_chunks": "Nombre de chunks à utiliser",
66
+ "rag_default_sys_prompt": "Tu es un assistant IA. Ta tâche est d'analyser le contexte fourni et de répondre aux questions en te basant UNIQUEMENT sur ce contexte. Si l'information n'est pas dans le contexte, dis-le clairement.",
67
+ "rag_error_fetching_models_ollama": "Erreur lors de la récupération des modèles Ollama : ",
68
+ "rag_error_calling_llm": "Erreur lors de l'appel au LLM : ",
69
+ "rag_processing" : "En cours de traitement...",
70
+ })
71
+
72
+ class RagllmPlugin(Plugin):
73
+ def __init__(self, name: str, plugin_manager):
74
+ super().__init__(name, plugin_manager)
75
+ self.config = self.load_llm_config()
76
+ self.embeddings = None
77
+ self.chunks = None
78
+
79
+ def load_llm_config(self) -> Dict:
80
+ with open('.llm-config.yml', 'r') as file:
81
+ return yaml.safe_load(file)
82
+
83
+ def get_tabs(self):
84
+ return [{"name": "RAG", "plugin": "ragllm"}]
85
+
86
+ def get_config_fields(self):
87
+ return {
88
+ "provider": {
89
+ "type": "select",
90
+ "label": t("rag_model_provider"),
91
+ "options": [("ollama", "Ollama"), ("groq", "Groq")],
92
+ "default": "ollama"
93
+ },
94
+ "llm_model": {
95
+ "type": "select",
96
+ "label": t("rag_llm_model"),
97
+ "options": [("none", "À charger...")],
98
+ "default": "ollama/qwen2"
99
+ },
100
+ "embedder": {
101
+ "type": "select",
102
+ "label": t("rag_embedder_model"),
103
+ "options": [
104
+ ("sentence-transformers/all-MiniLM-L6-v2", "all-MiniLM-L6-v2"),
105
+ ("nomic-ai/nomic-embed-text-v1.5", "nomic-embed-text-v1.5")
106
+ ],
107
+ "default": "sentence-transformers/all-MiniLM-L6-v2"
108
+ },
109
+ "similarity_method": {
110
+ "type": "select",
111
+ "label": t("rag_similarity_method"),
112
+ "options": [
113
+ ("cosine", "Cosinus"),
114
+ ("euclidean", "Distance euclidienne"),
115
+ ("manhattan", "Distance de Manhattan")
116
+ ],
117
+ "default": "cosine"
118
+ },
119
+ "llm_sys_prompt": {
120
+ "type": "textarea",
121
+ "label": t("rag_llm_sys_prompt"),
122
+ "default": t("rag_default_sys_prompt")
123
+ },
124
+ "chunk_size": {
125
+ "type": "number",
126
+ "label": t("rag_chunk_size"),
127
+ "default": 200
128
+ },
129
+ "top_k": {
130
+ "type": "number",
131
+ "label": t("rag_top_k_chunks"),
132
+ "default": 3
133
+ }
134
+ }
135
+
136
+ def get_config_ui(self, config):
137
+ updated_config = {}
138
+ for field, params in self.get_config_fields().items():
139
+ if params['type'] == 'select':
140
+ if field == 'llm_model':
141
+ provider = config.get('provider', 'ollama')
142
+ models = self.get_available_models(provider)
143
+ try:
144
+ default_index = models.index(config.get(field, params['default']))
145
+ except ValueError:
146
+ default_index = 0
147
+ updated_config[field] = st.selectbox(
148
+ params['label'],
149
+ options=models,
150
+ index=default_index
151
+ )
152
+ else:
153
+ options_list = [option[0] for option in params['options']]
154
+ try:
155
+ default_index = options_list.index(config.get(field, params['default']))
156
+ except ValueError:
157
+ default_index = 0
158
+ updated_config[field] = st.selectbox(
159
+ params['label'],
160
+ options=options_list,
161
+ format_func=lambda x: dict(params['options'])[x],
162
+ index=default_index
163
+ )
164
+ elif params['type'] == 'textarea':
165
+ updated_config[field] = st.text_area(
166
+ params['label'],
167
+ value=config.get(field, params['default'])
168
+ )
169
+ elif params['type'] == 'number':
170
+ updated_config[field] = st.number_input(
171
+ params['label'],
172
+ value=int(config.get(field, params['default'])),
173
+ step=1
174
+ )
175
+ else:
176
+ updated_config[field] = st.text_input(
177
+ params['label'],
178
+ value=config.get(field, params['default'])
179
+ )
180
+ return updated_config
181
+
182
+ def get_sidebar_config_ui(self, config: Dict[str, Any]) -> Dict[str, Any]:
183
+ available_models = self.get_available_models('ollama') + self.get_available_models('groq')
184
+ default_model = config.get('llm_model', available_models[0] if available_models else None)
185
+ selected_model = st.sidebar.selectbox(
186
+ t("rag_llm_model"),
187
+ options=available_models,
188
+ index=available_models.index(default_model) if default_model in available_models else 0,
189
+ key="ragllm_llm_model"
190
+ )
191
+ return {"llm_model": selected_model}
192
+
193
+ def get_available_models(self, provider: str) -> List[str]:
194
+ if provider == 'ollama':
195
+ try:
196
+ response = requests.get("http://localhost:11434/api/tags")
197
+ models = response.json()["models"]
198
+ return [f"ollama/{model['name']}" for model in models] + ["ollama/qwen2"]
199
+ except Exception as e:
200
+ st.error(f"{t('rag_error_fetching_models_ollama')}{str(e)}")
201
+ return ["ollama/qwen2"]
202
+ elif provider == 'groq':
203
+ return ["groq/llama3-70b-8192", "groq/mixtral-8x7b-32768"]
204
+ else:
205
+ return ["none"]
206
+
207
+ def process_rag_text(self, rag_text: str, chunk_size: int, embedder):
208
+ rag_text = rag_text.replace('\\n', ' ').replace('\\\'', "'")
209
+ mots = rag_text.split()
210
+ self.chunks = [' '.join(mots[i:i+chunk_size]) for i in range(0, len(mots), chunk_size)]
211
+ self.embeddings = np.vstack([self.get_embedding(c, embedder) for c in self.chunks])
212
+
213
+ def get_embedding(self, text: str, model: str) -> np.ndarray:
214
+ tokenizer = AutoTokenizer.from_pretrained(model)
215
+ model = AutoModel.from_pretrained(model, trust_remote_code=True).to(DEVICE)
216
+ inputs = tokenizer(text, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(DEVICE)
217
+ with torch.no_grad():
218
+ model_output = model(**inputs)
219
+ return mean_pooling(model_output, inputs['attention_mask']).cpu().numpy()
220
+
221
+ def calculate_similarity(self, query_embedding: np.ndarray, method: str) -> np.ndarray:
222
+ if method == 'cosine':
223
+ return cosine_similarity(query_embedding.reshape(1, -1), self.embeddings)[0]
224
+ elif method == 'euclidean':
225
+ return -euclidean_distances(query_embedding.reshape(1, -1), self.embeddings)[0]
226
+ elif method == 'manhattan':
227
+ return -manhattan_distances(query_embedding.reshape(1, -1), self.embeddings)[0]
228
+ else:
229
+ raise ValueError("Méthode de similarité non reconnue")
230
+
231
+ def get_context(self, query: str, config: Dict[str, Any]) -> tuple:
232
+ query_embedding = self.get_embedding(query, config['ragllm']['embedder'])
233
+ similarities = self.calculate_similarity(query_embedding, config['ragllm']['similarity_method'])
234
+ top_indices = np.argsort(similarities)[-config['ragllm']['top_k']:][::-1]
235
+ context = "\n\n".join([self.chunks[i] for i in top_indices])
236
+ return context, [self.chunks[i] for i in top_indices]
237
+
238
+ def call_llm(self, prompt: str, sysprompt: str) -> str:
239
+ try:
240
+ llm_model = st.session_state.ragllm_llm_model
241
+ #print(f"---------------------------------------\nCalling LLM {llm_model} \n with sysprompt {sysprompt} \n and prompt {prompt} \n and context len of {len(context)}")
242
+ messages = [
243
+ {"role": "system", "content": sysprompt},
244
+ {"role": "user", "content": prompt}
245
+ ]
246
+ response = completion(model=llm_model, messages=messages)
247
+ return response['choices'][0]['message']['content']
248
+ except Exception as e:
249
+ return f"{t('rag_error_calling_llm')}{str(e)}"
250
+
251
+ def free_llm(self):
252
+ try:
253
+ llm_model = st.session_state.ragllm_llm_model
254
+ if llm_model.startswith("ollama/"):
255
+ ollama_model = llm_model.split("/")[1]
256
+ response = requests.post(
257
+ "http://localhost:11434/api/generate",
258
+ json={
259
+ "model": ollama_model,
260
+ "prompt": "bye",
261
+ "keep_alive": 0
262
+ }
263
+ )
264
+ return response.json()['response']
265
+ except Exception as e:
266
+ return f"{t('rag_error_calling_llm')}{str(e)}"
267
+
268
+ def process_with_llm(self, prompt: str, sysprompt: str, context: str) -> str:
269
+ return self.call_llm(f"Contexte : {context}\n\nQuestion : {prompt}", sysprompt)
270
+
271
+ def run(self, config):
272
+ st.write(t("rag_plugin_loaded"))
273
+
274
+ # Initialiser rag_text avec la valeur de session_state si elle existe, sinon utiliser une chaîne vide
275
+ if 'rag_text' not in st.session_state:
276
+ st.session_state.rag_text = ""
277
+ if 'rag_question' not in st.session_state:
278
+ st.session_state.rag_question = "Question"
279
+
280
+ rag_text = st.text_area(t("rag_enter_text"), height=200, value=st.session_state.rag_text, key="rag_text_key")
281
+ user_prompt = st.text_area(t("rag_enter_question"), value=st.session_state.rag_question, key="rag_prompt_key")
282
+ st.session_state.rag_text = rag_text # Mettre à jour la valeur dans session_state
283
+ st.session_state.rag_question = user_prompt
284
+
285
+ if st.button(t("rag_button_get_answer"), key="get_answer_button"):
286
+
287
+ with st.spinner(t("rag_processing")):
288
+ if rag_text:
289
+ self.process_rag_text(rag_text, config['ragllm']['chunk_size'], config['ragllm']['embedder'])
290
+ st.success(t("rag_success_text_processed"))
291
+ else:
292
+ st.warning(t("rag_warning_enter_text"))
293
+ if user_prompt and self.embeddings is not None:
294
+ context, citations = self.get_context(user_prompt, config)
295
+ response = self.process_with_llm(user_prompt, config['ragllm']['llm_sys_prompt'], context)
296
+
297
+ st.write(t("rag_answer"))
298
+ st.write(response)
299
+
300
+ st.write(t("rag_citations"))
301
+ for i, citation in enumerate(citations, 1):
302
+ st.write(f"{i}. {citation[:100]}...")
303
+ elif self.embeddings is None:
304
+ st.warning(t("rag_warning_process_text_first"))
305
+ else:
306
+ st.warning(t("rag_warning_enter_question"))
plugins/scansite.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app import Plugin
2
+ import streamlit as st
3
+ import sqlite3
4
+ import requests
5
+ from bs4 import BeautifulSoup
6
+ from datetime import datetime
7
+ import ollama
8
+ from global_vars import t, translations
9
+
10
+ # Ajout des traductions spécifiques à ce plugin
11
+ translations["en"].update({
12
+ "scansite_title": "News Aggregator",
13
+ "total_links": "Total number of links",
14
+ "annotated_links": "Number of annotated links",
15
+ "known_tags": "Known tags",
16
+ "reset_database": "Reset database",
17
+ "database_reset_success": "Database reset successfully",
18
+ "launch_scan": "Launch scan",
19
+ "scan_complete": "Scan complete",
20
+ "no_articles": "No articles to display.",
21
+ "page": "Page",
22
+ "previous_page": "Previous page",
23
+ "next_page": "Next page",
24
+ "new_articles": "New Articles",
25
+ "rated_articles": "Rated Articles",
26
+ "clicked_not_rated": "Clicked but not rated Articles",
27
+ "tagged_articles": "Tagged Articles",
28
+ "ignored_articles": "Ignored Articles",
29
+ "excluded_articles": "Excluded Articles",
30
+ "rating": "Rating",
31
+ "tags": "Tags",
32
+ "exclude": "Exclude",
33
+ "sources": "Sources",
34
+ "update": "Update",
35
+ "delete": "Delete",
36
+ "add_new_source": "Add a new source (URL)",
37
+ "add_source": "Add source",
38
+ "new_tag": "New tag",
39
+ "new_tag_description": "New tag description",
40
+ "add_tag": "Add tag",
41
+ "work_directory": "Work Directory",
42
+ })
43
+
44
+ translations["fr"].update({
45
+ "scansite_title": "Agrégateur de Nouvelles",
46
+ "total_links": "Nombre total de liens",
47
+ "annotated_links": "Nombre de liens annotés",
48
+ "known_tags": "Tags connus",
49
+ "reset_database": "Réinitialiser la base de données",
50
+ "database_reset_success": "Base de données réinitialisée",
51
+ "launch_scan": "Lancer le scan",
52
+ "scan_complete": "Scan terminé",
53
+ "no_articles": "Aucun article à afficher.",
54
+ "page": "Page",
55
+ "previous_page": "Page précédente",
56
+ "next_page": "Page suivante",
57
+ "new_articles": "Nouveaux Articles",
58
+ "rated_articles": "Articles Notés",
59
+ "clicked_not_rated": "Articles Cliqués non notés",
60
+ "tagged_articles": "Articles Tagués",
61
+ "ignored_articles": "Articles Ignorés",
62
+ "excluded_articles": "Articles Exclus",
63
+ "rating": "Note",
64
+ "tags": "Tags",
65
+ "exclude": "Exclure",
66
+ "sources": "Sources",
67
+ "update": "Mettre à jour",
68
+ "delete": "Supprimer",
69
+ "add_new_source": "Ajouter une nouvelle source (URL)",
70
+ "add_source": "Ajouter source",
71
+ "new_tag": "Nouveau tag",
72
+ "new_tag_description": "Description du nouveau tag",
73
+ "add_tag": "Ajouter tag",
74
+ "work_directory": "Répertoire de travail",
75
+ })
76
+
77
+ class ScansitePlugin(Plugin):
78
+ def __init__(self, name, plugin_manager):
79
+ super().__init__(name, plugin_manager)
80
+ self.conn = self.get_connection()
81
+ self.c = self.conn.cursor()
82
+ self.init_db()
83
+
84
+ def get_connection(self):
85
+ return sqlite3.connect('news_app.db', check_same_thread=False)
86
+
87
+ def init_db(self):
88
+ current_version = self.get_db_version()
89
+ if current_version < 1:
90
+ self.c.execute('''CREATE TABLE IF NOT EXISTS sources
91
+ (id INTEGER PRIMARY KEY, url TEXT, title TEXT)''')
92
+ self.c.execute('''CREATE TABLE IF NOT EXISTS articles
93
+ (id INTEGER PRIMARY KEY, source_id INTEGER, url TEXT UNIQUE, title TEXT, date TEXT,
94
+ is_new INTEGER, is_excluded INTEGER DEFAULT 0)''')
95
+ self.c.execute('''CREATE TABLE IF NOT EXISTS user_actions
96
+ (id INTEGER PRIMARY KEY, article_id INTEGER, action TEXT, rating INTEGER, tags TEXT, timestamp TEXT)''')
97
+ self.c.execute('''CREATE TABLE IF NOT EXISTS tags
98
+ (id INTEGER PRIMARY KEY, name TEXT UNIQUE, description TEXT)''')
99
+ self.set_db_version(1)
100
+
101
+ # Add more version upgrades here
102
+ # if current_version < 2:
103
+ # self.c.execute('''ALTER TABLE articles ADD COLUMN new_column TEXT''')
104
+ # self.set_db_version(2)
105
+
106
+ self.conn.commit()
107
+
108
+ def get_db_version(self):
109
+ self.c.execute('''CREATE TABLE IF NOT EXISTS db_version (version INTEGER)''')
110
+ self.c.execute('SELECT version FROM db_version')
111
+ result = self.c.fetchone()
112
+ return result[0] if result else 0
113
+
114
+ def set_db_version(self, version):
115
+ self.c.execute('INSERT OR REPLACE INTO db_version (rowid, version) VALUES (1, ?)', (version,))
116
+ self.conn.commit()
117
+
118
+ def get_tabs(self):
119
+ return [{"name": t("scansite_title"), "plugin": "scansite"}]
120
+
121
+ def run(self, config):
122
+ st.title(t("scansite_title"))
123
+
124
+ total_links, annotated_links = self.get_stats()
125
+ st.write(f"{t('total_links')} : {total_links}")
126
+ st.write(f"{t('annotated_links')} : {annotated_links}")
127
+
128
+ all_tags = self.get_all_tags()
129
+ st.write(f"{t('known_tags')} :", ", ".join(all_tags))
130
+
131
+ if st.button(t("reset_database")):
132
+ self.reset_database()
133
+ st.success(t("database_reset_success"))
134
+
135
+ if st.button(t("launch_scan")):
136
+ self.launch_scan()
137
+ st.success(t("scan_complete"))
138
+
139
+ self.display_tabs()
140
+
141
+ def get_stats(self):
142
+ total_links = self.c.execute("SELECT COUNT(*) FROM articles WHERE is_excluded = 0").fetchone()[0]
143
+ annotated_links = self.c.execute("""
144
+ SELECT COUNT(DISTINCT article_id) FROM user_actions
145
+ WHERE action IN ('click', 'rate', 'tag')
146
+ """).fetchone()[0]
147
+ return total_links, annotated_links
148
+
149
+ def get_all_tags(self):
150
+ return [row[0] for row in self.c.execute("SELECT name FROM tags").fetchall()]
151
+
152
+ def reset_database(self):
153
+ self.c.execute("DROP TABLE IF EXISTS sources")
154
+ self.c.execute("DROP TABLE IF EXISTS articles")
155
+ self.c.execute("DROP TABLE IF EXISTS user_actions")
156
+ self.c.execute("DROP TABLE IF EXISTS tags")
157
+ self.conn.commit()
158
+ self.init_db()
159
+
160
+ def launch_scan(self):
161
+ sources = self.c.execute("SELECT * FROM sources").fetchall()
162
+ for source in sources:
163
+ self.mark_not_new(source[0])
164
+ links = self.scan_new_links(source[0], source[1])
165
+ for link, title in links:
166
+ self.c.execute("""
167
+ INSERT OR IGNORE INTO articles (source_id, url, title, date, is_new, is_excluded)
168
+ VALUES (?, ?, ?, ?, 1, 0)
169
+ """, (source[0], link, title, datetime.now().strftime('%Y-%m-%d')))
170
+ self.conn.commit()
171
+
172
+ def display_tabs(self):
173
+ tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([
174
+ t("new_articles"), t("rated_articles"), t("clicked_not_rated"),
175
+ t("tagged_articles"), t("ignored_articles"), t("excluded_articles")
176
+ ])
177
+
178
+ all_tags = self.get_all_tags()
179
+
180
+ with tab1:
181
+ st.header(t("new_articles"))
182
+ self.display_paginated_articles(self.get_new_articles(), all_tags, "nouveaux")
183
+
184
+ with tab2:
185
+ st.header(t("rated_articles"))
186
+ self.display_paginated_articles(self.get_rated_articles(), all_tags, "notes")
187
+
188
+ with tab3:
189
+ st.header(t("clicked_not_rated"))
190
+ self.display_paginated_articles(self.get_clicked_not_rated_articles(), all_tags, "cliques")
191
+
192
+ with tab4:
193
+ st.header(t("tagged_articles"))
194
+ self.display_paginated_articles(self.get_tagged_articles(), all_tags, "tagues")
195
+
196
+ with tab5:
197
+ st.header(t("ignored_articles"))
198
+ self.display_paginated_articles(self.get_ignored_articles(), all_tags, "ignores")
199
+
200
+ with tab6:
201
+ st.header(t("excluded_articles"))
202
+ self.display_paginated_articles(self.get_excluded_articles(), all_tags, "exclus")
203
+
204
+ def display_paginated_articles(self, articles, all_tags, tab_name, items_per_page=20):
205
+ if not articles:
206
+ st.write(t("no_articles"))
207
+ return
208
+
209
+ total_pages = (len(articles) - 1) // items_per_page + 1
210
+
211
+ page_key = f"{tab_name}_page"
212
+ if page_key not in st.session_state:
213
+ st.session_state[page_key] = 1
214
+
215
+ page = st.number_input(t("page"), min_value=1, max_value=total_pages, value=st.session_state[page_key], key=f"{tab_name}_number_input")
216
+ st.session_state[page_key] = page
217
+
218
+ start_idx = (page - 1) * items_per_page
219
+ end_idx = start_idx + items_per_page
220
+
221
+ for article in articles[start_idx:end_idx]:
222
+ self.display_article(article, all_tags, tab_name)
223
+
224
+ col1, col2, col3 = st.columns(3)
225
+ with col1:
226
+ if page > 1:
227
+ if st.button(t("previous_page"), key=f"{tab_name}_prev"):
228
+ st.session_state[page_key] = page - 1
229
+ st.rerun()
230
+ with col3:
231
+ if page < total_pages:
232
+ if st.button(t("next_page"), key=f"{tab_name}_next"):
233
+ st.session_state[page_key] = page + 1
234
+ st.rerun()
235
+ with col2:
236
+ st.write(f"{t('page')} {page}/{total_pages}")
237
+
238
+ def display_article(self, article, all_tags, tab_name):
239
+ article_id = article[0]
240
+
241
+ col1, col2, col3, col4, col5 = st.columns([3, 0.5, 1, 2, 1])
242
+
243
+ with col1:
244
+ summary_key = f"{tab_name}_summary_{article_id}"
245
+ if summary_key not in st.session_state:
246
+ st.session_state[summary_key] = None
247
+
248
+ if st.button(article[3], key=f"{tab_name}_article_{article_id}"):
249
+ summary = self.get_article_summary(article[2])
250
+ st.session_state[summary_key] = summary
251
+ self.c.execute("INSERT INTO user_actions (article_id, action, timestamp) VALUES (?, ?, ?)",
252
+ (article_id, 'click', datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
253
+ self.c.execute("UPDATE articles SET is_new = 0 WHERE id = ?", (article_id,))
254
+ self.conn.commit()
255
+
256
+ if st.session_state[summary_key]:
257
+ st.write(st.session_state[summary_key])
258
+
259
+ with col2:
260
+ st.markdown(f"[🔗]({article[2]})")
261
+
262
+ with col3:
263
+ rating_key = f"{tab_name}_rating_{article_id}"
264
+ current_rating = self.get_article_rating(article_id)
265
+ rating = st.slider(t("rating"), 0, 5, current_rating, key=rating_key)
266
+ if rating != current_rating:
267
+ self.c.execute("INSERT INTO user_actions (article_id, action, rating, timestamp) VALUES (?, ?, ?, ?)",
268
+ (article_id, 'rate', rating, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
269
+ self.conn.commit()
270
+
271
+ with col4:
272
+ tags_key = f"{tab_name}_tags_{article_id}"
273
+ current_tags = self.get_article_tags(article_id)
274
+ selected_tags = st.multiselect(t("tags"), all_tags, default=current_tags, key=tags_key)
275
+ if set(selected_tags) != set(current_tags):
276
+ tags_str = ','.join(selected_tags)
277
+ self.c.execute("INSERT INTO user_actions (article_id, action, tags, timestamp) VALUES (?, ?, ?, ?)",
278
+ (article_id, 'tag', tags_str, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
279
+ self.conn.commit()
280
+
281
+ with col5:
282
+ exclude_key = f"{tab_name}_exclude_{article_id}"
283
+ if st.button(t("exclude"), key=exclude_key):
284
+ self.c.execute("UPDATE articles SET is_excluded = 1 WHERE id = ?", (article_id,))
285
+ self.conn.commit()
286
+ st.rerun()
287
+
288
+ def get_config_ui(self, config):
289
+ updated_config = {}
290
+
291
+ updated_config['sources'] = st.header(t("sources"))
292
+ sources = self.c.execute("SELECT * FROM sources").fetchall()
293
+ for source in sources:
294
+ col1, col2, col3 = st.columns([3, 1, 1])
295
+ with col1:
296
+ new_title = st.text_input(f"{t('update')} {source[1]}", value=source[2], key=f"source_title_{source[0]}")
297
+ with col2:
298
+ if st.button(t("update"), key=f"update_source_{source[0]}"):
299
+ self.c.execute("UPDATE sources SET title = ? WHERE id = ?", (new_title, source[0]))
300
+ self.conn.commit()
301
+ with col3:
302
+ if st.button(t("delete"), key=f"delete_source_{source[0]}"):
303
+ self.c.execute("DELETE FROM sources WHERE id = ?", (source[0],))
304
+ self.conn.commit()
305
+
306
+ new_url = st.text_input(t("add_new_source"))
307
+ if st.button(t("add_source")):
308
+ title = self.fetch_page_title(new_url)
309
+ self.c.execute("INSERT INTO sources (url, title) VALUES (?, ?)", (new_url, title))
310
+ self.conn.commit()
311
+
312
+ st.header(t("tags"))
313
+ tags = self.get_all_tags_with_descriptions()
314
+ for tag, description in tags:
315
+ col1, col2, col3, col4 = st.columns([2, 3, 1, 1])
316
+ with col1:
317
+ st.text(tag)
318
+ with col2:
319
+ new_description = st.text_input(f"{t('update')} {tag}", value=description, key=f"tag_desc_{tag}")
320
+ with col3:
321
+ if st.button(t("update"), key=f"update_tag_{tag}"):
322
+ self.add_or_update_tag(tag, new_description)
323
+ with col4:
324
+ if st.button(t("delete"), key=f"delete_tag_{tag}"):
325
+ self.delete_tag(tag)
326
+
327
+ new_tag = st.text_input(t("new_tag"))
328
+ new_tag_description = st.text_input(t("new_tag_description"))
329
+ if st.button(t("add_tag")):
330
+ self.add_or_update_tag(new_tag, new_tag_description)
331
+
332
+ # Ajout des configurations modifiées au dictionnaire updated_config
333
+ updated_config["sources"] = sources
334
+ updated_config["new_source_url"] = new_url
335
+ updated_config["tags"] = tags
336
+ updated_config["new_tag"] = new_tag
337
+ updated_config["new_tag_description"] = new_tag_description
338
+
339
+ return updated_config
340
+
341
+ def fetch_page_title(self, url):
342
+ try:
343
+ response = requests.get(url)
344
+ soup = BeautifulSoup(response.text, 'html.parser')
345
+ return soup.title.string
346
+ except:
347
+ return url
348
+
349
+ def mark_not_new(self, source_id):
350
+ self.c.execute("UPDATE articles SET is_new = 0 WHERE source_id = ?", (source_id,))
351
+ self.conn.commit()
352
+
353
+ def scan_new_links(self, source_id, url):
354
+ links = self.scan_links(url)
355
+ filtered_links = []
356
+ for link, title in links:
357
+ self.c.execute("SELECT id, is_excluded FROM articles WHERE url = ?", (link,))
358
+ result = self.c.fetchone()
359
+ if result is None:
360
+ filtered_links.append((link, title))
361
+ return filtered_links
362
+
363
+ def scan_links(self, url):
364
+ links = set()
365
+ try:
366
+ response = requests.get(url)
367
+ soup = BeautifulSoup(response.text, 'html.parser')
368
+ for link in soup.find_all('a'):
369
+ href = link.get('href')
370
+ title = link.text.strip() or href
371
+ if href and href.startswith('http'):
372
+ try:
373
+ article_response = requests.get(href)
374
+ article_soup = BeautifulSoup(article_response.text, 'html.parser')
375
+ if article_soup.find('article'):
376
+ links.add((href, title))
377
+ except:
378
+ pass
379
+ except:
380
+ st.error(f"Erreur lors du scan de {url}")
381
+ return list(links)
382
+
383
+ def get_article_summary(self, url, model="qwen2"):
384
+ prompt = f"Résumez brièvement l'article à cette URL : {url}"
385
+ response = ollama.generate(model=model, prompt=prompt)
386
+ return response['response']
387
+
388
+ def get_new_articles(self):
389
+ return self.c.execute("""
390
+ SELECT * FROM articles
391
+ WHERE is_new = 1
392
+ AND is_excluded = 0
393
+ AND id NOT IN (
394
+ SELECT DISTINCT article_id
395
+ FROM user_actions
396
+ WHERE action IN ('click', 'rate', 'tag')
397
+ )
398
+ ORDER BY date DESC
399
+ """).fetchall()
400
+
401
+ def get_rated_articles(self):
402
+ return self.c.execute("""
403
+ SELECT DISTINCT a.*
404
+ FROM articles a
405
+ JOIN user_actions ua ON a.id = ua.article_id
406
+ WHERE ua.action = 'rate'
407
+ AND a.is_excluded = 0
408
+ ORDER BY ua.timestamp DESC
409
+ """).fetchall()
410
+
411
+ def get_clicked_not_rated_articles(self):
412
+ return self.c.execute("""
413
+ SELECT DISTINCT a.*
414
+ FROM articles a
415
+ JOIN user_actions ua ON a.id = ua.article_id
416
+ WHERE ua.action = 'click'
417
+ AND a.is_excluded = 0
418
+ AND a.id NOT IN (
419
+ SELECT article_id
420
+ FROM user_actions
421
+ WHERE action IN ('rate', 'tag')
422
+ )
423
+ ORDER BY ua.timestamp DESC
424
+ """).fetchall()
425
+
426
+ def get_tagged_articles(self):
427
+ return self.c.execute("""
428
+ SELECT DISTINCT a.*
429
+ FROM articles a
430
+ JOIN user_actions ua ON a.id = ua.article_id
431
+ WHERE ua.action = 'tag'
432
+ AND a.is_excluded = 0
433
+ AND a.id NOT IN (
434
+ SELECT article_id
435
+ FROM user_actions
436
+ WHERE action IN ('rate', 'click')
437
+ )
438
+ ORDER BY ua.timestamp DESC
439
+ """).fetchall()
440
+
441
+ def get_ignored_articles(self):
442
+ return self.c.execute("""
443
+ SELECT * FROM articles
444
+ WHERE is_new = 0
445
+ AND is_excluded = 0
446
+ AND id NOT IN (
447
+ SELECT DISTINCT article_id
448
+ FROM user_actions
449
+ WHERE action IN ('click', 'rate', 'tag')
450
+ )
451
+ ORDER BY date DESC
452
+ """).fetchall()
453
+
454
+ def get_excluded_articles(self):
455
+ return self.c.execute("""
456
+ SELECT * FROM articles
457
+ WHERE is_excluded = 1
458
+ ORDER BY date DESC
459
+ """).fetchall()
460
+
461
+ def get_article_rating(self, article_id):
462
+ self.c.execute("SELECT rating FROM user_actions WHERE article_id = ? AND action = 'rate' ORDER BY timestamp DESC LIMIT 1", (article_id,))
463
+ result = self.c.fetchone()
464
+ return result[0] if result else 0
465
+
466
+ def get_article_tags(self, article_id):
467
+ self.c.execute("SELECT tags FROM user_actions WHERE article_id = ? AND action = 'tag' ORDER BY timestamp DESC LIMIT 1", (article_id,))
468
+ result = self.c.fetchone()
469
+ return result[0].split(',') if result and result[0] else []
470
+
471
+ def get_all_tags_with_descriptions(self):
472
+ return self.c.execute("SELECT name, description FROM tags").fetchall()
473
+
474
+ def add_or_update_tag(self, name, description):
475
+ self.c.execute("INSERT OR REPLACE INTO tags (name, description) VALUES (?, ?)", (name, description))
476
+ self.conn.commit()
477
+
478
+ def delete_tag(self, name):
479
+ self.c.execute("DELETE FROM tags WHERE name = ?", (name,))
480
+ self.conn.commit()
481
+
482
+ def get_reference_data(self):
483
+ # Récupérer les articles avec leur rating
484
+ self.c.execute("""
485
+ SELECT a.id, a.url, a.title, COALESCE(ua.rating, 0) as rating
486
+ FROM articles a
487
+ LEFT JOIN (
488
+ SELECT article_id, rating
489
+ FROM user_actions
490
+ WHERE action = 'rate'
491
+ GROUP BY article_id
492
+ HAVING MAX(timestamp)
493
+ ) ua ON a.id = ua.article_id
494
+ WHERE a.is_excluded = 0
495
+ ORDER BY rating DESC, a.date DESC
496
+ """)
497
+ articles = self.c.fetchall()
498
+
499
+ # Séparer les articles en valides (notés) et rejetés (non notés)
500
+ reference_data_valid = [(article[1], article[2], article[3]) for article in articles if article[3] > 0]
501
+ reference_data_rejected = [(article[1], article[2]) for article in articles if article[3] == 0]
502
+
503
+ return reference_data_valid, reference_data_rejected
plugins/webrankings.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ from bs4 import BeautifulSoup
4
+ import pandas as pd
5
+ import torch
6
+ from transformers import pipeline
7
+ from sentence_transformers import SentenceTransformer, util
8
+ import concurrent.futures
9
+ import time
10
+ import sys
11
+ from sklearn.feature_extraction.text import TfidfVectorizer
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+ from transformers import AutoTokenizer, AutoModel
14
+ import numpy as np
15
+ from scipy import stats
16
+ from PyDictionary import PyDictionary
17
+ import matplotlib.pyplot as plt
18
+ from scipy import stats
19
+ import litellm
20
+ import re
21
+ import sentencepiece
22
+ import random
23
+ from global_vars import t, translations
24
+ from app import Plugin
25
+
26
+ from embeddings_ft import finetune as finetune_embeddings
27
+ from bart_ft import finetune as finetune_bart
28
+
29
+ from webrankings_helper import *
30
+ from plugins.scansite import ScansitePlugin
31
+
32
+ #from data import reference_data_valid, reference_data_rejected
33
+ #reference_data = reference_data_valid + reference_data_rejected
34
+
35
+ # Ajout des traductions spécifiques à ce plugin
36
+ translations["en"].update({
37
+ "webrankings_title": "Comparative os sorter",
38
+ "clear_memory": "Clear Memory",
39
+ "enter_topic": "Enter the topic you're interested in (e.g. longevity):",
40
+ "use_keyword_expansion": "Use keyword expansion",
41
+ "test_content": "Also test link content in addition to titles",
42
+ "select_llm_models": "Select LLM models to use",
43
+ "select_zero_shot_models": "Select zero-shot models to use",
44
+ "select_embedding_models": "Select embedding models to use",
45
+ "analyze_button": "Analyze",
46
+ "loading_models": "Loading models and analyzing links...",
47
+ "expanded_keywords": "Expanded keywords:",
48
+ "analysis_completed": "Analysis completed in {:.2f} seconds",
49
+ "evaluation_results": "Evaluation results with optimal thresholds:",
50
+ "summary_table": "Summary table of scores",
51
+ "optimal_thresholds": "Optimal thresholds:",
52
+ "spearman_comparison": "Comparison of Spearman correlations",
53
+ "methods": "Methods",
54
+ "spearman_correlation": "Spearman correlation coefficient",
55
+ "results_for": "Results for {}",
56
+ "device_info": "Device used for inference: {}",
57
+ "finetune_bart_title": "BART Fine-tuning Interface",
58
+ "finetune_embeddings_title": "Embeddings Fine-tuning Interface",
59
+ })
60
+
61
+ translations["fr"].update({
62
+ "webrankings_title": "Analyseur de classeurs",
63
+ "clear_memory": "Vider la mémoire",
64
+ "enter_topic": "Entrez le sujet qui vous intéresse (ex: longévité):",
65
+ "use_keyword_expansion": "Utiliser l'expansion des mots-clés",
66
+ "test_content": "Tester aussi le contenu des liens en plus des titres",
67
+ "select_llm_models": "Sélectionnez les modèles LLM à utiliser",
68
+ "select_zero_shot_models": "Sélectionnez les modèles zero-shot à utiliser",
69
+ "select_embedding_models": "Sélectionnez les modèles d'embedding à utiliser",
70
+ "analyze_button": "Analyser",
71
+ "loading_models": "Chargement des modèles et analyse des liens...",
72
+ "expanded_keywords": "Mots-clés étendus :",
73
+ "analysis_completed": "Analyse terminée en {:.2f} secondes",
74
+ "evaluation_results": "Résultats d'évaluation avec les seuils optimaux :",
75
+ "summary_table": "Tableau récapitulatif des scores",
76
+ "optimal_thresholds": "Seuils optimaux :",
77
+ "spearman_comparison": "Comparaison des corrélations de Spearman",
78
+ "methods": "Méthodes",
79
+ "spearman_correlation": "Coefficient de corrélation de Spearman",
80
+ "results_for": "Résultats pour {}",
81
+ "device_info": "Dispositif utilisé pour l'inférence : {}",
82
+ "finetune_bart_title": "Interface de Fine-tuning BART",
83
+ "finetune_embeddings_title": "Interface de Fine-tuning des Embeddings",
84
+ })
85
+
86
+
87
+ # Liste des modèles LLM
88
+ llm_models = [] #["ollama/llama3", "ollama/llama3.1", "ollama/qwen2", "ollama/phi3:medium-128k", "ollama/openhermes"]
89
+
90
+ # Liste des modèles zero-shot
91
+ zero_shot_models = [
92
+ ("facebook/bart-large-mnli", "facebook/bart-large-mnli"),
93
+ ("bart-large-ft", "./bart-large-ft")
94
+ #("cross-encoder/nli-deberta-v3-base", "cross-encoder/nli-deberta-v3-base")
95
+ ]
96
+
97
+ # Liste des modèles d'embedding
98
+ embedding_models = [
99
+ ("paraphrase-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2"),
100
+ ("all-MiniLM-L6-v2", "all-MiniLM-L6-v2"),
101
+ ("nomic-embed-text-v1", "nomic-ai/nomic-embed-text-v1"),
102
+ ("embeddings-ft", "./embeddings-ft")
103
+ ]
104
+
105
+
106
+ class WebrankingsPlugin(Plugin):
107
+ def __init__(self, name, plugin_manager):
108
+ super().__init__(name, plugin_manager)
109
+ self.scansite_plugin = ScansitePlugin('scansite', plugin_manager)
110
+
111
+ def get_tabs(self):
112
+ return [
113
+ {"name": t("webrankings_title"), "plugin": "webrankings"}
114
+ ]
115
+
116
+ def run(self, config):
117
+ tab1, tab2, tab3 = st.tabs([t("webrankings_title"), t("finetune_bart_title"), t("finetune_embeddings_title")])
118
+ reference_data_valid, reference_data_rejected = self.scansite_plugin.get_reference_data()
119
+ reference_data = reference_data_valid + [(url, title, 0) for url, title in reference_data_rejected]
120
+
121
+ with tab1:
122
+ st.title(t("webrankings_title"))
123
+ if st.button(t("clear_memory")):
124
+ torch.cuda.empty_cache()
125
+ torch.cuda.synchronize()
126
+ clear_globals()
127
+ reset_cuda_context()
128
+
129
+ topic = st.text_input(t("enter_topic"), value="longevity, health, healthspan, lifespan")
130
+ use_synonyms = st.checkbox(t("use_keyword_expansion"), value=False)
131
+ check_content = st.checkbox(t("test_content"), value=False)
132
+
133
+ selected_llm_models = st.multiselect(t("select_llm_models"), llm_models, default=llm_models)
134
+ selected_zero_shot_models = st.multiselect(t("select_zero_shot_models"), [m[0] for m in zero_shot_models], default=[m[0] for m in zero_shot_models])
135
+ selected_embedding_models = st.multiselect(t("select_embedding_models"), [m[0] for m in embedding_models], default=[m[0] for m in embedding_models])
136
+
137
+ if st.button(t("analyze_button")):
138
+ with st.spinner(t("loading_models")):
139
+ device = "cuda" if torch.cuda.is_available() else "cpu"
140
+
141
+ # Préparation des modèles
142
+ zero_shot_classifiers = {name: pipeline("zero-shot-classification", model=model, device=device)
143
+ for name, model in zero_shot_models if name in selected_zero_shot_models}
144
+ embedding_models_dict = {}
145
+ for name, model in embedding_models:
146
+ import os
147
+ if name == "embeddings-ft":
148
+ if os.path.exists('./embeddings-ft'):
149
+ embedding_models_dict[name] = SentenceTransformer('./embeddings-ft', trust_remote_code=True).to(device)
150
+ else:
151
+ embedding_models_dict[name] = SentenceTransformer(model, trust_remote_code=True).to(device)
152
+ bert_models = [AutoModel.from_pretrained('bert-base-uncased').to(device)]
153
+ tfidf_objects = [TfidfVectorizer()]
154
+ #release_vram(zero_shot_classifiers, embedding_models_dict, bert_models, tfidf_objects)
155
+
156
+ # Expansion des mots-clés (utilisant le premier modèle LLM sélectionné)
157
+ if use_synonyms and selected_llm_models:
158
+ expanded_query = []
159
+ for word in topic.split():
160
+ expanded_query.extend(expand_keywords_llm(word, llm_model=selected_llm_models[0]))
161
+ expanded_query = " ".join(expanded_query)
162
+ st.write("Mots-clés étendus :", expanded_query)
163
+ else:
164
+ expanded_query = topic
165
+
166
+ start_time = time.time()
167
+ # Analyse pour chaque lien
168
+ results = []
169
+ for title, link,note in reference_data:
170
+ result = analyze_link(
171
+ title, link, topic, zero_shot_classifiers, embedding_models_dict,
172
+ expanded_query, selected_llm_models, check_content
173
+ )
174
+ if result is not None:
175
+ results.append(result)
176
+ end_time = time.time()
177
+
178
+ # Libération de la mémoire VRAM et des autres ressources
179
+ release_vram(zero_shot_classifiers, embedding_models_dict, bert_models, tfidf_objects)
180
+
181
+ # Création du DataFrame avec tous les résultats
182
+ df = pd.DataFrame(results)
183
+ print(f"Analyse terminée en {end_time - start_time:.2f} secondes")
184
+
185
+ st.success(t("analysis_completed").format(end_time - start_time))
186
+
187
+ # Évaluation et affichage des résultats
188
+ evaluation_results = {}
189
+ optimal_thresholds = {}
190
+ for column in df.columns:
191
+ if column != "Titre":
192
+ method_scores = df.set_index("Titre")[column].to_dict()
193
+
194
+ optimal_threshold = find_optimal_threshold(
195
+ [item[0] for item in reference_data_valid],
196
+ [item[0] for item in reference_data_rejected],
197
+ method_scores
198
+ )
199
+ optimal_thresholds[column] = optimal_threshold
200
+
201
+ evaluation_results[column] = evaluate_ranking(
202
+ [item[0] for item in reference_data_valid],
203
+ [item[0] for item in reference_data_rejected],
204
+ method_scores,
205
+ optimal_threshold, False
206
+ )
207
+
208
+ # Affichage des résultats
209
+ st.write(t("evaluation_results"))
210
+ eval_df = pd.DataFrame(evaluation_results).T
211
+ st.dataframe(eval_df)
212
+
213
+ st.subheader(t("summary_table"))
214
+ st.dataframe(df)
215
+
216
+ st.write(t("optimal_thresholds"))
217
+ st.json(optimal_thresholds)
218
+
219
+ # Graphique de comparaison des corrélations de Spearman
220
+ spearman_scores = [results['spearman_correlation'] for results in evaluation_results.values()]
221
+ plt.figure(figsize=(15, 8))
222
+ plt.bar(evaluation_results.keys(), spearman_scores)
223
+ plt.title(t("spearman_comparison"))
224
+ plt.xlabel(t("methods"))
225
+ plt.ylabel(t("spearman_correlation"))
226
+ plt.xticks(rotation=90, ha='right')
227
+ plt.tight_layout()
228
+ st.pyplot(plt)
229
+
230
+ # Affichage des résultats pour chaque méthode
231
+ for column in df.columns:
232
+ if column != "Titre":
233
+ st.subheader(f"Résultats pour {column}")
234
+ df_method = df[["Titre", column]].sort_values(column, ascending=False)
235
+ threshold = find_optimal_threshold(
236
+ [item[0] for item in reference_data_valid],
237
+ [item[0] for item in reference_data_rejected],
238
+ df_method.set_index("Titre")[column].to_dict()
239
+ )
240
+ df_method = df_method[df_method[column] > threshold]
241
+ st.dataframe(df_method)
242
+
243
+ with tab2:
244
+ st.title(t("finetune_bart_title"))
245
+ num_epochs = st.number_input("Nombre d'époques", min_value=1, max_value=10, value=2)
246
+ lr = st.number_input("Learning Rate", min_value=1e-6, max_value=1e-1, value=2e-5, format="%.6f", step=1e-5)
247
+ weight_decay = st.number_input("Poids de Décroissance (Weight Decay)", min_value=0.0, max_value=0.1, value=0.01, step=0.005)
248
+ batch_size = st.number_input("Taille du Batch", min_value=1, max_value=16, value=1)
249
+ start = st.slider("Score initial des données valides", min_value=0.0, max_value=1.0, value=0.9, step=0.01)
250
+ model_name = st.text_input("Nom du modèle", value='facebook/bart-large-mnli')
251
+ num_warmup_steps = st.number_input("Nombre d'étapes de Warmup", min_value=0, max_value=100, value=0)
252
+
253
+ # Bouton pour lancer le fine-tuning
254
+ if st.button("Lancer le fine-tuning"):
255
+ with st.spinner("Fine-tuning en cours..."):
256
+ finetune_bart(num_epochs=num_epochs, lr=lr, weight_decay=weight_decay,
257
+ batch_size=batch_size, model_name=model_name, output_model='./bart-large-ft',
258
+ num_warmup_steps=num_warmup_steps)
259
+ st.success("Fine-tuning terminé et modèle sauvegardé.")
260
+
261
+ with tab3:
262
+ st.title(t("finetune_embeddings_title"))
263
+ num_epochs_emb = st.number_input("Nombre d'époques (Embeddings)", min_value=1, max_value=100, value=10)
264
+ lr_emb = st.number_input("Learning Rate (Embeddings)", min_value=1e-6, max_value=1e-1, value=2e-5, format="%.6f", step=5e-6)
265
+ weight_decay_emb = st.number_input("Poids de Décroissance (Weight Decay) (Embeddings)", min_value=0.0, max_value=0.1, value=0.01, step=0.005)
266
+ batch_size_emb = st.number_input("Taille du Batch (Embeddings)", min_value=1, max_value=32, value=16)
267
+ start_emb = st.slider("Score initial des données valides (Embeddings)", min_value=0.0, max_value=1.0, value=0.9, step=0.01)
268
+ model_name_emb = st.selectbox("Modèle d'embeddings de base", ["nomic-ai/nomic-embed-text-v1", "all-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2"])
269
+ margin_erb = st.slider("Marge (Embeddings)", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
270
+
271
+ # Bouton pour lancer le fine-tuning des embeddings
272
+ if st.button("Lancer le fine-tuning des embeddings"):
273
+ with st.spinner("Fine-tuning des embeddings en cours..."):
274
+
275
+ finetune_embeddings(model_name=model_name_emb, output_model_name="./embeddings-ft",
276
+ num_epochs=num_epochs_emb,
277
+ learning_rate=lr_emb,
278
+ weight_decay=weight_decay_emb,
279
+ batch_size=batch_size_emb,
280
+ )
281
+ st.success("Fine-tuning des embeddings terminé et modèle sauvegardé.")
282
+
283
+ # Affichage de l'information sur le dispositif utilisé
284
+ device = "GPU (CUDA)" if torch.cuda.is_available() else "CPU"
285
+ st.sidebar.info(t("device_info").format(device))
webrankings_helper.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ import pandas as pd
4
+ import torch
5
+ from transformers import pipeline
6
+ from sentence_transformers import SentenceTransformer, util
7
+ import concurrent.futures
8
+ import time
9
+ import sys
10
+ from sklearn.feature_extraction.text import TfidfVectorizer
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+ from transformers import AutoTokenizer, AutoModel
13
+ import numpy as np
14
+ from scipy import stats
15
+ from PyDictionary import PyDictionary
16
+ import matplotlib.pyplot as plt
17
+ from scipy import stats
18
+ import litellm
19
+ import re
20
+ import sentencepiece
21
+ import random
22
+
23
+ def score_with_llm(title, topic, llm_model):
24
+ prompt = f"""Evaluate the relevance of the following article to the topic '{topic}'.
25
+ Article title: {title}
26
+ Give a final relevance score between 0 and 1, where 1 is very relevant and 0 is not relevant at all.
27
+ Respond only with a number between 0 and 1."""
28
+
29
+ try:
30
+ response = litellm.completion(
31
+ model=llm_model,
32
+ messages=[{"role": "user", "content": prompt}],
33
+ max_tokens=10
34
+ )
35
+
36
+ score_match = re.search(r'\d+(\.\d+)?', response.choices[0].message.content.strip())
37
+ if score_match:
38
+ score = float(score_match.group())
39
+ print(f"Score LLM : {score}")
40
+ return max(0, min(score, 1))
41
+ else:
42
+ print(f"Could not extract a score from LLM response: {response.choices[0].message.content}")
43
+ return None
44
+ except Exception as e:
45
+ print(f"Error in scoring with LLM {llm_model}: {str(e)}")
46
+ return None
47
+
48
+ def expand_keywords_llm(keyword, max_synonyms=3, llm_model="ollama/qwen2"):
49
+ prompt = f"""Please provide up to {max_synonyms} synonyms or closely related terms for the word or phrase: "{keyword}".
50
+ Return only the list of synonyms, separated by commas, without any additional explanation."""
51
+
52
+ try:
53
+ response = litellm.completion(
54
+ model=llm_model,
55
+ messages=[{"role": "user", "content": prompt}],
56
+ max_tokens=50
57
+ )
58
+
59
+ synonyms = [s.strip() for s in response.choices[0].message.content.split(',')]
60
+ return [keyword] + synonyms[:max_synonyms]
61
+ except Exception as e:
62
+ print(f"Error in expanding keywords with LLM {llm_model}: {str(e)}")
63
+ return [keyword]
64
+
65
+ # Fonction pour obtenir les liens de la page d'accueil
66
+ def get_homepage_links(url):
67
+ response = requests.get(url)
68
+ soup = BeautifulSoup(response.text, 'html.parser')
69
+ links = soup.find_all('a', href=True)
70
+ return [(link.text.strip(), link['href']) for link in links if link.text.strip()]
71
+
72
+ # Fonction pour obtenir le contenu d'un article
73
+ def get_article_content(url):
74
+ try:
75
+ print(f"Récupération du contenu de : {url}")
76
+ response = requests.get(url)
77
+ print(f"Taille de la réponse HTTP : {len(response.content)} octets") # Affiche le nombre d'octets de la réponse HTTP
78
+ soup = BeautifulSoup(response.text, 'html.parser')
79
+ print(f"Taille de l'objet soup : {sys.getsizeof(soup)} octets") # Affiche la taille en mémoire de l'objet soup
80
+ article = soup.find('article')
81
+ if article:
82
+ paragraphs = article.find_all('p')
83
+ content = ' '.join([p.text for p in paragraphs])
84
+ print(f"Paragraphes récupéré : {len(content)} caractères")
85
+ return content
86
+ print("Aucun contenu d'article trouvé")
87
+ return ""
88
+ except Exception as e:
89
+ print(f"Erreur lors de la récupération du contenu : {str(e)}")
90
+ return ""
91
+
92
+ # Fonction pour l'analyse zero-shot
93
+ def zero_shot_analysis(text, topic, classifier):
94
+ if not text:
95
+ print("Texte vide pour l'analyse zero-shot")
96
+ return 0.0
97
+ result = classifier(text, candidate_labels=[topic, f"not {topic}"], multi_label=False)
98
+ print(f"Score zero-shot : {result['scores'][0]}")
99
+ return result['scores'][0]
100
+
101
+ # Fonction pour l'analyse par embeddings
102
+ def embedding_analysis(text, topic_embedding, model):
103
+ if not text:
104
+ print("Texte vide pour l'analyse par embeddings")
105
+ return 0.0
106
+ text_embedding = model.encode([text], convert_to_tensor=True)
107
+ similarity = util.pytorch_cos_sim(text_embedding, topic_embedding).item()
108
+ print(f"Score embedding : {similarity}")
109
+ return similarity
110
+
111
+ from sklearn.feature_extraction.text import TfidfVectorizer
112
+ from sklearn.metrics.pairwise import cosine_similarity
113
+ #import nltk
114
+ #from nltk.corpus import wordnet
115
+ #nltk.download('wordnet')
116
+
117
+ def preprocess_text(text):
118
+ # Tokenize the text
119
+ tokens = text.lower().split()
120
+ # Expand each token with its synonyms
121
+ expanded_tokens = []
122
+ for token in tokens:
123
+ synonyms = set()
124
+ for syn in wordnet.synsets(token):
125
+ for lemma in syn.lemmas():
126
+ synonyms.add(lemma.name())
127
+ expanded_tokens.extend(list(synonyms))
128
+ return ' '.join(expanded_tokens)
129
+
130
+ def improved_tfidf_similarity(texts, query):
131
+ # Preprocess texts and query
132
+ preprocessed_texts = [preprocess_text(text) for text in texts]
133
+ preprocessed_query = preprocess_text(query)
134
+
135
+ # Combine texts and query for vectorization
136
+ all_texts = preprocessed_texts + [preprocessed_query]
137
+
138
+ # Use TfidfVectorizer with custom parameters
139
+ vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=1, smooth_idf=True)
140
+ tfidf_matrix = vectorizer.fit_transform(all_texts)
141
+
142
+ # Calculate cosine similarity
143
+ cosine_similarities = cosine_similarity(tfidf_matrix[-1], tfidf_matrix[:-1]).flatten()
144
+
145
+ # Normalize similarities to avoid zero scores
146
+ normalized_similarities = (cosine_similarities - cosine_similarities.min()) / (cosine_similarities.max() - cosine_similarities.min())
147
+
148
+ return normalized_similarities
149
+
150
+ from sklearn.feature_extraction.text import TfidfVectorizer
151
+ from sklearn.metrics.pairwise import cosine_similarity
152
+ import numpy as np
153
+
154
+ def improved_tfidf_similarity_v2(texts, query):
155
+ # Combine texts and query, treating each word or phrase as a separate document
156
+ all_docs = [word.strip() for text in texts for word in text.split(',')] + [word.strip() for word in query.split(',')]
157
+
158
+ # Create TF-IDF matrix
159
+ vectorizer = TfidfVectorizer()
160
+ tfidf_matrix = vectorizer.fit_transform(all_docs)
161
+
162
+ # Calculate document vectors by summing the TF-IDF vectors of their words
163
+ doc_vectors = []
164
+ query_vector = np.zeros((1, tfidf_matrix.shape[1]))
165
+
166
+ current_doc = 0
167
+ for i, doc in enumerate(all_docs):
168
+ if i < len(all_docs) - len(query.split(',')): # If it's part of the texts
169
+ if current_doc == len(texts):
170
+ break
171
+ if doc in texts[current_doc]:
172
+ doc_vectors.append(tfidf_matrix[i].toarray())
173
+ else:
174
+ current_doc += 1
175
+ doc_vectors.append(tfidf_matrix[i].toarray())
176
+ else: # If it's part of the query
177
+ query_vector += tfidf_matrix[i].toarray()
178
+
179
+ doc_vectors = np.array([np.sum(doc, axis=0) for doc in doc_vectors])
180
+
181
+ # Calculate cosine similarity
182
+ similarities = cosine_similarity(query_vector, doc_vectors).flatten()
183
+
184
+ # Normalize similarities to avoid zero scores
185
+ normalized_similarities = (similarities - similarities.min()) / (similarities.max() - similarities.min() + 1e-8)
186
+
187
+ return normalized_similarities
188
+
189
+ # Example usage:
190
+ # texts = ["longevity, health, aging", "computer science, AI"]
191
+ # query = "longevity, life extension, anti-aging"
192
+ # results = improved_tfidf_similarity_v2(texts, query)
193
+ # print(results)
194
+
195
+ # Nouvelles fonctions
196
+ def tfidf_similarity(texts, query):
197
+ vectorizer = TfidfVectorizer()
198
+ tfidf_matrix = vectorizer.fit_transform(texts + [query])
199
+ cosine_similarities = cosine_similarity(tfidf_matrix[-1], tfidf_matrix[:-1]).flatten()
200
+ return cosine_similarities
201
+
202
+ def bert_similarity(texts, query, model_name='bert-base-uncased'):
203
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
204
+ model = AutoModel.from_pretrained(model_name)
205
+
206
+ def get_embedding(text):
207
+ inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
208
+ with torch.no_grad():
209
+ outputs = model(**inputs)
210
+ return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
211
+
212
+ query_embedding = get_embedding(query)
213
+ text_embeddings = [get_embedding(text) for text in texts]
214
+
215
+ similarities = [cosine_similarity([query_embedding], [text_embedding])[0][0] for text_embedding in text_embeddings]
216
+ return similarities
217
+
218
+ # Fonction principale d'analyse modifiée
219
+ def analyze_link(title, link, topic, zero_shot_classifiers, embedding_models, expanded_query, llm_models, testcontent):
220
+ print(f"\nAnalyse de : {title}")
221
+
222
+ results = {
223
+ "Titre": title,
224
+ #"TF-IDF (titre)": improved_tfidf_similarity_v2([title], expanded_query)[0],
225
+ #"BERT (titre)": bert_similarity([title], expanded_query)[0],
226
+ }
227
+
228
+ # Zero-shot analysis
229
+ for name, classifier in zero_shot_classifiers.items():
230
+ results[f"Zero-shot (titre) - {name}"] = zero_shot_analysis(title, topic, classifier)
231
+
232
+ # Embedding analysis
233
+ for name, model in embedding_models.items():
234
+ topic_embedding = model.encode([expanded_query], convert_to_tensor=True)
235
+ results[f"Embeddings (titre) - {name}"] = embedding_analysis(title, topic_embedding, model)
236
+
237
+ # LLM analysis
238
+ for model in llm_models:
239
+ results[f"LLM Score - {model}"] = score_with_llm(title, topic, model)
240
+
241
+ if testcontent:
242
+ content = get_article_content(link)
243
+ #results["TF-IDF (contenu)"] = improved_tfidf_similarity_v2([content], expanded_query)[0]
244
+ #results["BERT (contenu)"]= bert_similarity([content], expanded_query)[0]
245
+
246
+ # Zero-shot analysis
247
+ for name, classifier in zero_shot_classifiers.items():
248
+ results[f"Zero-shot (contenu) - {name}"] = zero_shot_analysis(content, topic, classifier)
249
+
250
+ # Embedding analysis
251
+ for name, model in embedding_models.items():
252
+ topic_embedding = model.encode([expanded_query], convert_to_tensor=True)
253
+ results[f"Embeddings (contenu) - {name}"] = embedding_analysis(content, topic_embedding, model)
254
+
255
+ # LLM analysis
256
+ for model in llm_models:
257
+ results[f"LLM Content Score - {model}"] = score_with_llm(content, topic, model)
258
+
259
+
260
+ return results
261
+
262
+ from scipy import stats
263
+
264
+ def evaluate_ranking(reference_data_valid, reference_data_rejected, method_scores, threshold, silent):
265
+ simple_score = 0
266
+ true_positives = 0
267
+ false_positives = 0
268
+ true_negatives = 0
269
+ false_negatives = 0
270
+
271
+ # Créer une liste de tous les éléments avec leur statut (1 pour valide, 0 pour rejeté)
272
+ all_items = [(item, 1) for item in reference_data_valid] + [(item, 0) for item in reference_data_rejected]
273
+
274
+ # Trier les éléments selon leur score dans la méthode
275
+ all_items_temp = all_items.copy()
276
+ # correct false positive if method spit out same score for all
277
+ #random.shuffle(all_items_temp)
278
+ all_items_temp.reverse()
279
+ sorted_method = sorted([(item, method_scores.get(item, 0)) for item, _ in all_items_temp],
280
+ key=lambda x: x[1], reverse=True)
281
+
282
+ # Créer des listes pour le calcul de la corrélation de Spearman
283
+ reference_ranks = []
284
+ method_ranks = []
285
+
286
+ for i, (item, status) in enumerate(all_items):
287
+ method_score = method_scores.get(item, 0)
288
+ method_rank = next(j for j, (it, score) in enumerate(sorted_method) if it == item)
289
+
290
+ reference_ranks.append(i)
291
+ method_ranks.append(method_rank)
292
+
293
+ if status == 1: # Item valide
294
+ if method_score >= threshold:
295
+ simple_score += 1
296
+ true_positives += 1
297
+ else:
298
+ simple_score -= 1
299
+ false_negatives += 1
300
+ else: # Item rejeté
301
+ if method_score < threshold:
302
+ simple_score += 1
303
+ true_negatives += 1
304
+ else:
305
+ simple_score -= 1
306
+ false_positives += 1
307
+
308
+ # Calculer le coefficient de corrélation de Spearman
309
+ if not silent:
310
+ print("+++")
311
+ print(reference_ranks)
312
+ print("---")
313
+ print(method_ranks)
314
+ spearman_corr, _ = stats.spearmanr(reference_ranks, method_ranks)
315
+
316
+ # Calculer la précision, le rappel et le F1-score
317
+ precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
318
+ recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
319
+ f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
320
+
321
+ return {
322
+ "simple_score": simple_score,
323
+ "spearman_correlation": spearman_corr,
324
+ "precision": precision,
325
+ "recall": recall,
326
+ "f1_score": f1_score,
327
+ }
328
+
329
+ def find_optimal_threshold(reference_data_valid, reference_data_rejected, method_scores):
330
+ best_score = float('-inf')
331
+ best_threshold = 0
332
+
333
+ for threshold in np.arange(0, 1.05, 0.05):
334
+ result = evaluate_ranking(
335
+ reference_data_valid,
336
+ reference_data_rejected,
337
+ method_scores,
338
+ threshold, True
339
+ )
340
+ if result['simple_score'] > best_score:
341
+ best_score = result['simple_score']
342
+ best_threshold = threshold
343
+
344
+ return best_threshold
345
+
346
+ def reset_cuda_context():
347
+ torch.cuda.empty_cache()
348
+ torch.cuda.ipc_collect()
349
+ if torch.cuda.is_available():
350
+ torch.cuda.set_device(torch.cuda.current_device())
351
+ torch.cuda.synchronize()
352
+
353
+ import gc
354
+ def clear_models():
355
+ global zero_shot_classifiers, embedding_models_dict, bert_models, tfidf_objects
356
+
357
+ for classifier in zero_shot_classifiers.values():
358
+ del classifier
359
+ zero_shot_classifiers.clear()
360
+
361
+ for model in embedding_models_dict.values():
362
+ del model
363
+ embedding_models_dict.clear()
364
+
365
+ for model in bert_models:
366
+ del model
367
+ bert_models.clear()
368
+
369
+ for vectorizer in tfidf_objects:
370
+ del vectorizer
371
+ tfidf_objects.clear()
372
+
373
+ torch.cuda.empty_cache()
374
+ gc.collect()
375
+
376
+ def clear_globals():
377
+ for name in list(globals()):
378
+ if isinstance(globals()[name], (torch.nn.Module, torch.Tensor)):
379
+ del globals()[name]
380
+
381
+ def release_vram(zero_shot_classifiers, embedding_models, bert_models, tfidf_objects):
382
+ # Supprimer les objets zero-shot classifiers
383
+ for model in zero_shot_classifiers.values():
384
+ del model
385
+
386
+ # Supprimer les objets embedding models
387
+ for model in embedding_models.values():
388
+ del model
389
+
390
+ # Supprimer les objets bert models
391
+ for model in bert_models:
392
+ del model
393
+
394
+ # Supprimer les objets tfidf objects
395
+ for obj in tfidf_objects:
396
+ del obj
397
+
398
+ # Vider le cache de la mémoire GPU
399
+ torch.cuda.empty_cache()
400
+ torch.cuda.synchronize()
401
+ gc.collect()
402
+ clear_globals()
403
+ reset_cuda_context()
404
+
405
+ def load_finetuned_model(model_path):
406
+ checkpoint = torch.load(model_path)
407
+ base_model = AutoModel.from_pretrained(checkpoint['base_model_name'])
408
+ model = EmbeddingModel(base_model)
409
+ model.load_state_dict(checkpoint['model_state_dict'])
410
+ return model