import fasttext
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class FastTextSummarizer:
    """Extractive summarizer: ranks sentences by cosine similarity between
    each sentence embedding and the mean (document) embedding."""

    def __init__(self, model_file):
        # model_file: path to a trained fastText .bin model.
        self.model = fasttext.load_model(model_file)

    def sentence_embedding(self, sentence):
        # Average the word vectors of the sentence. fastText builds vectors
        # from character n-grams, so get_word_vector() also covers
        # out-of-vocabulary words; scanning self.model.words (a linear-time
        # list lookup per word) is both slow and unnecessary.
        words = sentence.split()
        if not words:
            return np.zeros(self.model.get_dimension())
        return np.mean([self.model.get_word_vector(word) for word in words], axis=0)
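
    # A minimal alternative sketch, assuming the fastText Python bindings
    # (>= 0.9) are in use: the library's own get_sentence_vector() averages
    # L2-normalized word vectors internally and could replace the manual
    # mean above. It raises a ValueError on input containing '\n', hence the
    # replace() below. This method is an optional addition, not part of the
    # original class.
    def sentence_embedding_builtin(self, sentence):
        return self.model.get_sentence_vector(sentence.replace('\n', ' '))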

    def summarize(self, text, num_sentences=3):
        # Split on the Amharic full stop '።'; str.split drops the delimiter,
        # so it is re-added when the summary is joined below.
        sentences = [s.strip() for s in text.split('።') if s.strip()]
        if not sentences:
            return ''
        # Embed each sentence, then represent the document as their mean.
        sentence_embeddings = np.vstack([self.sentence_embedding(s) for s in sentences])
        document_embedding = sentence_embeddings.mean(axis=0, keepdims=True)
        # Score each sentence by cosine similarity to the document embedding.
        similarities = cosine_similarity(document_embedding, sentence_embeddings).flatten()
        # Keep the highest-scoring sentences, restored to their original order.
        top_indices = sorted(similarities.argsort()[::-1][:num_sentences])
        summary = '። '.join(sentences[i] for i in top_indices) + '።'
        return summary
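

# Usage sketch. Assumptions (not part of the class above): the pretrained
# Amharic vectors 'cc.am.300.bin' from
# https://fasttext.cc/docs/en/crawl-vectors.html have been downloaded
# locally, and 'article.am.txt' is a hypothetical UTF-8 file of
# '።'-delimited Amharic text.
if __name__ == '__main__':
    summarizer = FastTextSummarizer('cc.am.300.bin')
    with open('article.am.txt', encoding='utf-8') as f:
        text = f.read()
    print(summarizer.summarize(text, num_sentences=3))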