Spaces:
Sleeping
Sleeping
File size: 2,404 Bytes
44a025a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple, Any, Optional
# Global variables to store model and data.
# All four start as None and are filled lazily by the functions below.
# SentenceTransformer instance, set by initialize_model().
_model = None
# Array loaded from data/question_embeddings.npy by load_embeddings().
_question_embeddings = None
# Array loaded from data/answer_embeddings.npy by load_embeddings().
_answer_embeddings = None
# Parsed content of data/qa_data.json, set by load_embeddings().
_qa_data = None
def initialize_model() -> "SentenceTransformer":
    """
    Initialize the model once and store it in a global variable.

    The "pkshatech/GLuCoSE-base-ja" SentenceTransformer is only constructed
    when the module-level cache is still empty; later calls reuse the cached
    instance.

    Returns:
        The cached SentenceTransformer instance. (The function has always
        returned the model; the annotation now reflects that instead of
        claiming ``None``.)
    """
    global _model
    if _model is None:
        _model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
    return _model
def get_model() -> SentenceTransformer:
    """
    Return the cached model, initializing it first when nothing is cached yet.
    """
    global _model
    # Fast path: a model has already been stored in the module-level cache.
    if _model is not None:
        return _model
    _model = initialize_model()
    return _model
def load_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """
    Load embeddings and QA data from files.

    Reads ``data/question_embeddings.npy``, ``data/answer_embeddings.npy`` and
    ``data/qa_data.json`` (paths are relative to the current working
    directory) and stores the results in the module-level cache.

    The three files are read into local variables first and the globals are
    only updated once every read has succeeded, so a failure part-way through
    never leaves the cache holding a mix of new and stale data.

    Returns:
        ``(question_embeddings, answer_embeddings, qa_data)`` on success, or
        ``(None, None, None)`` when any file is missing or cannot be read.
        Errors are reported with ``print`` rather than raised.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    try:
        question_embeddings = np.load("data/question_embeddings.npy")
        answer_embeddings = np.load("data/answer_embeddings.npy")
        with open("data/qa_data.json", "r", encoding="utf-8") as f:
            qa_data = json.load(f)
    except FileNotFoundError as e:
        print(f"Warning: Embeddings not found. {str(e)}")
        return None, None, None
    except Exception as e:
        # Best-effort loader: report the problem and signal failure via Nones.
        print(f"Error loading embeddings: {str(e)}")
        return None, None, None

    # Commit all three values together now that every read succeeded.
    _question_embeddings = question_embeddings
    _answer_embeddings = answer_embeddings
    _qa_data = qa_data
    return _question_embeddings, _answer_embeddings, _qa_data
def get_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """
    Return the cached embeddings and QA data, loading them on demand.

    A load is triggered whenever any of the three cached values is still None.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    cached = (_question_embeddings, _answer_embeddings, _qa_data)
    if any(item is None for item in cached):
        # At least one piece is missing, so refresh all three together.
        _question_embeddings, _answer_embeddings, _qa_data = load_embeddings()
    return _question_embeddings, _answer_embeddings, _qa_data
def reload_embeddings() -> bool:
    """
    Reload embeddings from files.

    Calls ``load_embeddings()`` and stores whatever it returns in the
    module-level cache. ``load_embeddings()`` reports its own failures by
    returning ``None`` values instead of raising, so failure is detected by
    checking for ``None`` explicitly rather than relying on ``len(None)``
    raising inside the success message.

    Returns:
        True when all three values were loaded, False otherwise. This
        function does not raise.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    try:
        loaded = load_embeddings()
        _question_embeddings, _answer_embeddings, _qa_data = loaded
        if any(item is None for item in loaded):
            print("Error reloading embeddings: embedding files could not be loaded.")
            return False
        print(f"Embeddings reloaded successfully. {len(_qa_data)} QA pairs available.")
        return True
    except Exception as e:
        # Kept as a safety net (e.g. QA data that has no length).
        print(f"Error reloading embeddings: {str(e)}")
        return False
|