# Taiken_chatbot_API/app/services/model_service.py
# Author: vumichien (commit 44a025a, "Add application file")
import json
from pathlib import Path
from typing import List, Dict, Tuple, Any, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
# Module-level cache shared by the loader/accessor functions in this module.
# Every entry starts as None and is populated lazily on first use.
_model: Optional[SentenceTransformer] = None  # shared sentence encoder
_question_embeddings: Optional[np.ndarray] = None  # loaded question vectors
_answer_embeddings: Optional[np.ndarray] = None  # loaded answer vectors
_qa_data: Optional[List[Dict[str, str]]] = None  # QA records read from JSON
def initialize_model() -> SentenceTransformer:
    """
    Load the sentence-embedding model once and cache it module-wide.

    On the first call a ``SentenceTransformer("pkshatech/GLuCoSE-base-ja")``
    is created and stored in the module-level ``_model``. Later calls reuse
    that cached instance instead of constructing a new one.

    Returns:
        The cached ``SentenceTransformer`` instance. Callers such as
        ``get_model`` assign this return value, so the annotation must not be
        ``None``.
    """
    global _model
    if _model is None:
        _model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
    return _model
def get_model() -> SentenceTransformer:
    """
    Return the shared embedding model, loading it lazily on first access.

    If the module-level ``_model`` is already set it is returned as-is;
    otherwise loading is delegated to ``initialize_model`` and the result is
    cached before being returned.
    """
    global _model
    if _model is not None:
        return _model
    _model = initialize_model()
    return _model
def load_embeddings(
    data_dir: str = "data",
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]:
    """
    Load question/answer embeddings and the QA records from disk.

    Reads ``question_embeddings.npy``, ``answer_embeddings.npy`` and
    ``qa_data.json`` (UTF-8) from ``data_dir``. On success the three values
    are stored in the module-level cache (``_question_embeddings``,
    ``_answer_embeddings``, ``_qa_data``) and returned.

    Args:
        data_dir: Directory containing the three files. The default
            ``"data"`` (relative to the current working directory) matches
            the previously hard-coded paths.

    Returns:
        ``(question_embeddings, answer_embeddings, qa_data)`` on success, or
        ``(None, None, None)`` if any file is missing or fails to load.
        Failures are reported with ``print`` rather than raised. The
        module-level cache is left untouched when loading fails.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    base = Path(data_dir)
    try:
        # Read everything into locals first, so that a failure part-way
        # through (e.g. a missing answer file) cannot leave the cache holding
        # new question embeddings next to stale answer embeddings.
        question_embeddings = np.load(base / "question_embeddings.npy")
        answer_embeddings = np.load(base / "answer_embeddings.npy")
        with open(base / "qa_data.json", "r", encoding="utf-8") as f:
            qa_data = json.load(f)
    except FileNotFoundError as e:
        print(f"Warning: Embeddings not found. {e}")
        return None, None, None
    except Exception as e:
        # Best-effort loader: callers check for the None triple instead of
        # handling exceptions.
        print(f"Error loading embeddings: {e}")
        return None, None, None

    # All three loaded: publish them to the cache together.
    _question_embeddings = question_embeddings
    _answer_embeddings = answer_embeddings
    _qa_data = qa_data
    return question_embeddings, answer_embeddings, qa_data
def get_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """
    Return the cached ``(question_embeddings, answer_embeddings, qa_data)``.

    Unless all three cached values are set, the whole triple is re-read via
    ``load_embeddings`` and stored back into the cache. Because a failed load
    yields ``(None, None, None)``, a later call will simply attempt the load
    again.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    cache_complete = all(
        item is not None
        for item in (_question_embeddings, _answer_embeddings, _qa_data)
    )
    if not cache_complete:
        _question_embeddings, _answer_embeddings, _qa_data = load_embeddings()
    return (_question_embeddings, _answer_embeddings, _qa_data)
def reload_embeddings() -> bool:
    """
    Reload embeddings and QA data from disk, replacing the cached copies.

    The reload is all-or-nothing. ``load_embeddings`` reports failure by
    returning ``(None, None, None)`` rather than by raising. In that case,
    or if anything else goes wrong, the previously cached values are
    restored. A missing or corrupt file therefore does not wipe out data
    that was already loaded.

    Returns:
        ``True`` if all three values were reloaded, ``False`` otherwise.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    # Snapshot the current cache: load_embeddings may assign some of these
    # globals before failing, and a failed reload must not discard good data.
    previous = (_question_embeddings, _answer_embeddings, _qa_data)
    try:
        question_embeddings, answer_embeddings, qa_data = load_embeddings()
        if question_embeddings is None or answer_embeddings is None or qa_data is None:
            # load_embeddings has already printed the underlying cause.
            raise ValueError("embedding files could not be loaded")
        qa_count = len(qa_data)
    except Exception as e:
        _question_embeddings, _answer_embeddings, _qa_data = previous
        print(f"Error reloading embeddings: {e}")
        return False

    _question_embeddings = question_embeddings
    _answer_embeddings = answer_embeddings
    _qa_data = qa_data
    print(f"Embeddings reloaded successfully. {qa_count} QA pairs available.")
    return True