johannoriel committed
Commit • 8e46350
1 Parent(s): 645a356
HF bug correction
plugins/ragllm.py +28 -9
plugins/ragllm.py
CHANGED
@@ -16,6 +16,7 @@ from langchain_huggingface import HuggingFaceEmbeddings
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 MAX_LENGTH = 512
 CHUNK_SIZE = 200
+CONFIG_FILE = '.llm-config.yml'

 def mean_pooling(model_output, attention_mask):
     token_embeddings = model_output[0]
@@ -52,6 +53,7 @@ translations["en"].update({
     "rag_error_calling_llm": "Error calling LLM: ",
     "rag_processing" : "Processing...",
     "rag_hf_api_key": "HuggingFace API Token",
+    "rag_config_file_missing": "Configuration file .llm-config.yml not found. This is required for Ollama and Groq providers.",
 })

 translations["fr"].update({
@@ -77,25 +79,23 @@ translations["fr"].update({
     "rag_error_calling_llm": "Erreur lors de l'appel au LLM : ",
     "rag_processing" : "En cours de traitement...",
     "rag_hf_api_key": "Token API HuggingFace",
+    "rag_config_file_missing": "Fichier de configuration .llm-config.yml non trouvé. Ce fichier est nécessaire pour les providers Ollama et Groq.",
 })

 class RagllmPlugin(Plugin):
     def __init__(self, name: str, plugin_manager):
         super().__init__(name, plugin_manager)
-        try:
-            self.config = self.load_llm_config()
-        except:
-            self.config = {}
         self.embeddings = None
         self.chunks = None
         self.hf_client = None
+        self.config = {}

     def load_llm_config(self) -> Dict:
-        try:
-            with open('.llm-config.yml', 'r') as file:
-                return yaml.safe_load(file)
-        except:
+        if not os.path.exists(CONFIG_FILE):
+            st.warning(t("rag_config_file_missing"))
             return {}
+        with open(CONFIG_FILE, 'r') as file:
+            return yaml.safe_load(file)

     def get_tabs(self):
         return [{"name": "RAG", "plugin": "ragllm"}]
@@ -161,6 +161,12 @@ class RagllmPlugin(Plugin):

     def get_config_ui(self, config):
         updated_config = {}
+
+        # Load config file only if provider is not huggingface
+        current_provider = config.get('provider', 'ollama')
+        if current_provider != 'huggingface':
+            self.config = self.load_llm_config()
+
         for field, params in self.get_config_fields().items():
             if params['type'] == 'select':
                 if field == 'llm_model':
@@ -203,11 +209,24 @@
                     params['label'],
                     value=config.get(field, params['default'])
                 )
+
+        if config.get('provider') == 'huggingface':
+            updated_config['hf_api_key'] = st.text_input(
+                t("rag_hf_api_key"),
+                type="password",
+                value=config.get('hf_api_key', '')
+            )
+
         return updated_config

     def get_sidebar_config_ui(self, config: Dict[str, Any]) -> Dict[str, Any]:
-        available_models = self.get_available_models()
+        provider = config.get('provider', 'ollama')
+        available_models = self.get_available_models(provider)
         default_model = config.get('llm_model', available_models[0] if available_models else None)
+
+        if default_model not in available_models:
+            default_model = available_models[0] if available_models else None
+
         selected_model = st.sidebar.selectbox(
             t("rag_llm_model"),
             options=available_models,
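
Taken together, the fix replaces two bare try/except blocks that silently swallowed a missing config file with an explicit existence check: on a Hugging Face Space there is usually no local .llm-config.yml, so the constructor previously failed quietly and left the Ollama and Groq providers unconfigured, while the huggingface provider only needs an API token, now collected through the config UI. Below is a minimal standalone sketch of the guard pattern, using the file name and Dict return type from the diff; the print call stands in for the plugin's st.warning, and everything else is an illustrative assumption, not the committed code verbatim.

import os
from typing import Dict

import yaml

CONFIG_FILE = '.llm-config.yml'

def load_llm_config() -> Dict:
    # Check for the expected-missing file up front and warn visibly,
    # instead of masking every failure with a bare except.
    if not os.path.exists(CONFIG_FILE):
        print(f"Configuration file {CONFIG_FILE} not found.")  # st.warning(...) in the plugin
        return {}
    with open(CONFIG_FILE, 'r') as file:
        # yaml.safe_load returns None for an empty file; normalizing
        # to {} here is an extra guard not present in the commit.
        return yaml.safe_load(file) or {}

One design note: the committed version returns whatever yaml.safe_load yields, so an empty .llm-config.yml would hand callers None rather than a dict; the "or {}" fallback above hedges against that case.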