am_text_summary / train /fast_text_trainer.py
berito's picture
train code added
a608bb4
raw
history blame
1.05 kB
import fasttext
class FastTextTrainer:
    """Train, persist, and query an unsupervised fastText embedding model.

    The trainer trains a model on ``corpus_file`` with
    ``fasttext.train_unsupervised``, saves it to ``model_file``, and serves
    word-vector lookups once a model has been trained or loaded.
    """

    def __init__(self, corpus_file: str, model_file: str = "fasttext_model.bin") -> None:
        """Create a trainer; no training or file I/O happens here.

        Args:
            corpus_file: Path to the training corpus, passed to fastText as
                its ``input`` argument.
            model_file: Path the trained model is saved to by ``train_model``
                and read from by ``load_model``. Defaults to
                ``"fasttext_model.bin"`` (relative to the working directory).
        """
        self.corpus_file = corpus_file
        self.model_file = model_file
        # Stays None until train_model() or load_model() succeeds; checked
        # by get_word_vector() before any lookup.
        self.model = None

    def train_model(
        self,
        model_type: str = "skipgram",
        dim: int = 100,
        epoch: int = 5,
        lr: float = 0.05,
        thread: int = 4,
        **kwargs,
    ):
        """Train an unsupervised model on the corpus and save it to ``model_file``.

        Args:
            model_type: fastText unsupervised model type (``"skipgram"`` or
                ``"cbow"``).
            dim: Dimensionality of the word vectors.
            epoch: Number of passes over the corpus.
            lr: Learning rate.
            thread: Number of training threads.
            **kwargs: Any further keyword arguments are forwarded unchanged
                to ``fasttext.train_unsupervised`` (e.g. ``minCount``,
                ``ws``, ``minn``, ``maxn``, ``wordNgrams``).

        Returns:
            The trained fastText model (also stored on ``self.model``).

        Errors raised by fastText (for example when the corpus cannot be
        opened) propagate to the caller.
        """
        print("Training FastText model...")
        self.model = fasttext.train_unsupervised(
            input=self.corpus_file,
            model=model_type,
            dim=dim,
            epoch=epoch,
            lr=lr,
            thread=thread,
            **kwargs,
        )
        self.model.save_model(self.model_file)
        print(f"Model trained and saved to {self.model_file}")
        return self.model

    def load_model(self):
        """Load a previously saved model from ``model_file``.

        Returns:
            The loaded fastText model (also stored on ``self.model``).

        Errors raised by ``fasttext.load_model`` (e.g. a missing file)
        propagate to the caller.
        """
        print(f"Loading FastText model from {self.model_file}...")
        self.model = fasttext.load_model(self.model_file)
        print("Model loaded successfully.")
        return self.model

    def get_word_vector(self, word: str):
        """Return the embedding vector for ``word`` from the current model.

        Args:
            word: The token to look up.

        Returns:
            Whatever the model's ``get_word_vector`` returns for ``word``.

        Raises:
            ValueError: If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Use `train_model` or `load_model` first.")
        return self.model.get_word_vector(word)