# app/services/embedding_service.py
import pandas as pd
import json
import re
import numpy as np
import os
from typing import List, Dict, Any
from app.services.model_service import get_model, reload_embeddings
# Ensure data directory exists
os.makedirs("data", exist_ok=True)
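# Artifacts written to data/ by this module (see the helpers below):
# raw.json, question_embeddings.npy, answer_embeddings.npy, and qa_data.json.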
def remove_prefix(text: str, prefix_pattern: str) -> str:
"""
Removes the prefix matching the given pattern from the text.
"""
return re.sub(prefix_pattern, "", text).strip()
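
# Illustrative example (not part of the original service): with the prefix
# patterns used in process_file below,
#   remove_prefix("Q12. What are the opening hours?", r"^Q\d+\.\s*")
# returns "What are the opening hours?".
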
def process_file(file_path: str, file_type: str) -> List[Dict[str, str]]:
"""
Process Excel or CSV file and extract question-answer pairs.
"""
if file_type == "excel":
df = pd.read_excel(file_path)
elif file_type == "csv":
df = pd.read_csv(file_path)
else:
raise ValueError("Unsupported file type. Use 'excel' or 'csv'.")
# Check if the necessary columns exist
if "θ³ͺ問" not in df.columns or "ε›žη­”" not in df.columns:
raise ValueError("The file must contain 'θ³ͺ問' and 'ε›žη­”' columns.")
# Initialize the list to store processed data
qa_list = []
    # Keep only rows that have both a question and an answer
    df = df.dropna(subset=["質問", "回答"])
# Iterate over each row in the DataFrame
    for _, row in df.iterrows():
        raw_question = str(row["質問"])
        raw_answer = str(row["回答"])
# Remove prefixes using regex patterns
question = remove_prefix(raw_question, r"^Q\d+\.\s*")
answer = remove_prefix(raw_answer, r"^A\.\s*")
qa_list.append({"question": question, "answer": answer})
return qa_list
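
# Expected shape (sketch, based on the prefix patterns above): a 質問 cell such as
# "Q1. How do I reset my password?" and a 回答 cell such as "A. Use the reset link
# on the login page" become
#   {"question": "How do I reset my password?", "answer": "Use the reset link on the login page"}
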
def save_raw_data(qa_list: List[Dict[str, str]]) -> None:
"""
Save the raw question-answer pairs to a JSON file.
"""
with open("data/raw.json", "w", encoding="utf-8") as json_file:
json.dump(qa_list, json_file, ensure_ascii=False, indent=2)
def create_and_save_embeddings(qa_list: List[Dict[str, str]]) -> None:
"""
Create embeddings for questions and answers and save them.
"""
questions = [item["question"] for item in qa_list]
answers = [item["answer"] for item in qa_list]
# Use the global model
model = get_model()
# Create embeddings for questions and answers
question_embeddings = model.encode(questions, convert_to_numpy=True)
answer_embeddings = model.encode(answers, convert_to_numpy=True)
# Save embeddings as numpy arrays
np.save("data/question_embeddings.npy", question_embeddings)
np.save("data/answer_embeddings.npy", answer_embeddings)
# Save the original data
with open("data/qa_data.json", "w", encoding="utf-8") as f:
json.dump(qa_list, f, ensure_ascii=False, indent=2)
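
# Sketch (assumption, not part of the original service): the saved arrays can be
# queried with cosine similarity against a new query, e.g.
#
#   question_embeddings = np.load("data/question_embeddings.npy")
#   query_vec = get_model().encode(["How do I reset my password?"], convert_to_numpy=True)[0]
#   scores = question_embeddings @ query_vec / (
#       np.linalg.norm(question_embeddings, axis=1) * np.linalg.norm(query_vec)
#   )
#   best_match_index = int(np.argmax(scores))
#
# The in-memory copies used by the chatbot are refreshed via reload_embeddings()
# (see process_and_create_embeddings below).
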
def process_and_create_embeddings(file_path: str, file_type: str) -> Dict[str, Any]:
"""
Process the input file and create embeddings.
"""
try:
qa_list = process_file(file_path, file_type)
save_raw_data(qa_list)
create_and_save_embeddings(qa_list)
# Reload embeddings into memory
reload_embeddings()
return {
"status": "success",
"message": "Embeddings created successfully",
"data": {"total_qa_pairs": len(qa_list)},
}
except Exception as e:
return {"status": "error", "message": str(e)}