File size: 3,242 Bytes
44a025a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import pandas as pd
import json
import re
import numpy as np
import os
from typing import List, Dict, Tuple, Any

from app.services.model_service import get_model, reload_embeddings

# Ensure the "data" directory exists as soon as this module is imported; the
# save functions in this module write their output files under it. The path is
# relative, so it is resolved against the process's current working directory.
os.makedirs("data", exist_ok=True)


def remove_prefix(text: str, prefix_pattern: str) -> str:
    """
    Strip text matching a regex from a string and trim the result.

    Every match of the pattern is replaced with the empty string, so the
    pattern should be anchored with ``^`` when only a leading prefix is
    meant to be removed.

    Args:
        text: The string to clean.
        prefix_pattern: Regular expression describing the text to remove.

    Returns:
        The cleaned string with leading and trailing whitespace removed.
    """
    matcher = re.compile(prefix_pattern)
    without_prefix = matcher.sub("", text)
    return without_prefix.strip()


def process_file(file_path: str, file_type: str) -> List[Dict[str, str]]:
    """
    Process Excel or CSV file and extract question-answer pairs.

    Args:
        file_path: Path of the spreadsheet to read.
        file_type: Either "excel" or "csv"; selects the pandas reader.

    Returns:
        One ``{"question": ..., "answer": ...}`` dict per row that has both
        a question and an answer, with the leading "Q<number>." / "A."
        markers and surrounding whitespace removed.

    Raises:
        ValueError: If ``file_type`` is not supported or a required column
            is missing from the file.
    """
    if file_type == "excel":
        df = pd.read_excel(file_path)
    elif file_type == "csv":
        df = pd.read_csv(file_path)
    else:
        raise ValueError("Unsupported file type. Use 'excel' or 'csv'.")

    # NOTE(review): these header literals look like mis-encoded (mojibake)
    # text; confirm they match the real spreadsheet headers before editing.
    question_col = "θ³ͺ問"
    answer_col = "ε›žη­”"

    # Check if the necessary columns exist
    if question_col not in df.columns or answer_col not in df.columns:
        raise ValueError(
            f"The file must contain '{question_col}' and '{answer_col}' columns."
        )

    # Only discard rows that lack a question or an answer. A blanket dropna()
    # would also throw away valid pairs whenever some unrelated column in the
    # sheet happens to have an empty cell.
    complete_rows = df.dropna(subset=[question_col, answer_col])

    # Walk the two columns directly instead of iterrows(): iterrows() builds
    # a Series per row and does not preserve each column's own dtype.
    return [
        {
            "question": remove_prefix(str(raw_question), r"^Q\d+\.\s*"),
            "answer": remove_prefix(str(raw_answer), r"^A\.\s*"),
        }
        for raw_question, raw_answer in zip(
            complete_rows[question_col], complete_rows[answer_col]
        )
    ]


def save_raw_data(qa_list: List[Dict[str, str]]) -> None:
    """
    Write the raw question-answer pairs to ``data/raw.json``.

    The list is serialised as UTF-8 JSON with a two-space indent, and
    non-ASCII characters are written as-is rather than escaped.

    Args:
        qa_list: Question-answer dicts to persist.
    """
    serialized = json.dumps(qa_list, ensure_ascii=False, indent=2)
    with open("data/raw.json", "w", encoding="utf-8") as out:
        out.write(serialized)


def create_and_save_embeddings(qa_list: List[Dict[str, str]]) -> None:
    """
    Encode every question and answer and persist the results under ``data/``.

    Writes three files: ``question_embeddings.npy``, ``answer_embeddings.npy``
    and ``qa_data.json`` (the input pairs as UTF-8 JSON).

    Args:
        qa_list: Dicts that each carry a "question" and an "answer" string.
    """
    question_texts = []
    answer_texts = []
    for pair in qa_list:
        question_texts.append(pair["question"])
        answer_texts.append(pair["answer"])

    # Shared model instance provided by the model service
    encoder = get_model()

    # Both encodings are computed before anything is written to disk;
    # dict insertion order keeps questions ahead of answers in both phases.
    arrays_by_path = {
        "data/question_embeddings.npy": encoder.encode(
            question_texts, convert_to_numpy=True
        ),
        "data/answer_embeddings.npy": encoder.encode(
            answer_texts, convert_to_numpy=True
        ),
    }
    for target_path, matrix in arrays_by_path.items():
        np.save(target_path, matrix)

    # Keep the source pairs next to the embedding files
    with open("data/qa_data.json", "w", encoding="utf-8") as out:
        json.dump(qa_list, out, ensure_ascii=False, indent=2)


def process_and_create_embeddings(file_path: str, file_type: str) -> Dict[str, Any]:
    """
    Run the full pipeline: parse the file, save the pairs, build embeddings.

    Any exception raised along the way is caught and reported in the
    returned dict instead of propagating to the caller.

    Args:
        file_path: Path of the spreadsheet to ingest.
        file_type: "excel" or "csv", forwarded to ``process_file``.

    Returns:
        On success, a dict with "status" set to "success", a "message", and
        "data" holding the number of pairs processed. On failure, a dict with
        "status" set to "error" and the exception text as "message".
    """
    try:
        pairs = process_file(file_path, file_type)
        save_raw_data(pairs)
        create_and_save_embeddings(pairs)

        # Ask the model service to reload embeddings now that the files on
        # disk have been rewritten
        reload_embeddings()

        outcome = {
            "status": "success",
            "message": "Embeddings created successfully",
            "data": {"total_qa_pairs": len(pairs)},
        }
    except Exception as exc:
        outcome = {"status": "error", "message": str(exc)}
    return outcome