johannoriel commited on
Commit
8e46350
1 Parent(s): 645a356

HF bug correction

Browse files
Files changed (1) hide show
  1. plugins/ragllm.py +28 -9
plugins/ragllm.py CHANGED
@@ -16,6 +16,7 @@ from langchain_huggingface import HuggingFaceEmbeddings
16
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  MAX_LENGTH = 512
18
  CHUNK_SIZE = 200
 
19
 
20
  def mean_pooling(model_output, attention_mask):
21
  token_embeddings = model_output[0]
@@ -52,6 +53,7 @@ translations["en"].update({
52
  "rag_error_calling_llm": "Error calling LLM: ",
53
  "rag_processing" : "Processing...",
54
  "rag_hf_api_key": "HuggingFace API Token",
 
55
  })
56
 
57
  translations["fr"].update({
@@ -77,25 +79,23 @@ translations["fr"].update({
77
  "rag_error_calling_llm": "Erreur lors de l'appel au LLM : ",
78
  "rag_processing" : "En cours de traitement...",
79
  "rag_hf_api_key": "Token API HuggingFace",
 
80
  })
81
 
82
  class RagllmPlugin(Plugin):
83
  def __init__(self, name: str, plugin_manager):
84
  super().__init__(name, plugin_manager)
85
- try:
86
- self.config = self.load_llm_config()
87
- except:
88
- self.config = {}
89
  self.embeddings = None
90
  self.chunks = None
91
  self.hf_client = None
 
92
 
93
  def load_llm_config(self) -> Dict:
94
- try:
95
- with open('.llm-config.yml', 'r') as file:
96
- return yaml.safe_load(file)
97
- except:
98
  return {}
 
 
99
 
100
  def get_tabs(self):
101
  return [{"name": "RAG", "plugin": "ragllm"}]
@@ -161,6 +161,12 @@ class RagllmPlugin(Plugin):
161
 
162
  def get_config_ui(self, config):
163
  updated_config = {}
 
 
 
 
 
 
164
  for field, params in self.get_config_fields().items():
165
  if params['type'] == 'select':
166
  if field == 'llm_model':
@@ -203,11 +209,24 @@ class RagllmPlugin(Plugin):
203
  params['label'],
204
  value=config.get(field, params['default'])
205
  )
 
 
 
 
 
 
 
 
206
  return updated_config
207
 
208
  def get_sidebar_config_ui(self, config: Dict[str, Any]) -> Dict[str, Any]:
209
- available_models = self.get_available_models('ollama') + self.get_available_models('groq')
 
210
  default_model = config.get('llm_model', available_models[0] if available_models else None)
 
 
 
 
211
  selected_model = st.sidebar.selectbox(
212
  t("rag_llm_model"),
213
  options=available_models,
 
16
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  MAX_LENGTH = 512
18
  CHUNK_SIZE = 200
19
+ CONFIG_FILE = '.llm-config.yml'
20
 
21
  def mean_pooling(model_output, attention_mask):
22
  token_embeddings = model_output[0]
 
53
  "rag_error_calling_llm": "Error calling LLM: ",
54
  "rag_processing" : "Processing...",
55
  "rag_hf_api_key": "HuggingFace API Token",
56
+ "rag_config_file_missing": "Configuration file .llm-config.yml not found. This is required for Ollama and Groq providers.",
57
  })
58
 
59
  translations["fr"].update({
 
79
  "rag_error_calling_llm": "Erreur lors de l'appel au LLM : ",
80
  "rag_processing" : "En cours de traitement...",
81
  "rag_hf_api_key": "Token API HuggingFace",
82
+ "rag_config_file_missing": "Fichier de configuration .llm-config.yml non trouvé. Ce fichier est nécessaire pour les providers Ollama et Groq.",
83
  })
84
 
85
  class RagllmPlugin(Plugin):
86
  def __init__(self, name: str, plugin_manager):
87
  super().__init__(name, plugin_manager)
 
 
 
 
88
  self.embeddings = None
89
  self.chunks = None
90
  self.hf_client = None
91
+ self.config = {}
92
 
93
  def load_llm_config(self) -> Dict:
94
+ if not os.path.exists(CONFIG_FILE):
95
+ st.warning(t("rag_config_file_missing"))
 
 
96
  return {}
97
+ with open(CONFIG_FILE, 'r') as file:
98
+ return yaml.safe_load(file)
99
 
100
  def get_tabs(self):
101
  return [{"name": "RAG", "plugin": "ragllm"}]
 
161
 
162
  def get_config_ui(self, config):
163
  updated_config = {}
164
+
165
+ # Load config file only if provider is not huggingface
166
+ current_provider = config.get('provider', 'ollama')
167
+ if current_provider != 'huggingface':
168
+ self.config = self.load_llm_config()
169
+
170
  for field, params in self.get_config_fields().items():
171
  if params['type'] == 'select':
172
  if field == 'llm_model':
 
209
  params['label'],
210
  value=config.get(field, params['default'])
211
  )
212
+
213
+ if config.get('provider') == 'huggingface':
214
+ updated_config['hf_api_key'] = st.text_input(
215
+ t("rag_hf_api_key"),
216
+ type="password",
217
+ value=config.get('hf_api_key', '')
218
+ )
219
+
220
  return updated_config
221
 
222
  def get_sidebar_config_ui(self, config: Dict[str, Any]) -> Dict[str, Any]:
223
+ provider = config.get('provider', 'ollama')
224
+ available_models = self.get_available_models(provider)
225
  default_model = config.get('llm_model', available_models[0] if available_models else None)
226
+
227
+ if default_model not in available_models:
228
+ default_model = available_models[0] if available_models else None
229
+
230
  selected_model = st.sidebar.selectbox(
231
  t("rag_llm_model"),
232
  options=available_models,