---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
- lstm
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task:
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
    - name: Test RMSE
      type: rmse
      value: 0.0139
    - name: Test R²
      type: r2
      value: 0.8766
    - name: Test Loss
      type: loss
      value: 0.0002
---

# Topic Drift Detector Model

## Version: v20241225_162244

This model detects topic drift in conversations using an enhanced attention-based architecture. It was trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.

## Model Architecture

- Multi-head attention mechanism (4 heads)
- Bidirectional LSTM (3 layers) for pattern detection
- Dynamic weight generation
- Semantic bridge detection
- Hidden dimension: 512
- Dropout rate: 0.2
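
The list above gives hyperparameters but not how the components connect. The PyTorch module below is a minimal, hypothetical sketch that wires the listed pieces together (4-head attention, a 3-layer bidirectional LSTM, hidden size 512, dropout 0.2) and accepts the same `[batch, 8 * 1024]` input as the Usage example. The class name `TopicDriftSketch`, the input projection, the mean pooling, and the sigmoid output range are all assumptions. The dynamic weight generation and semantic bridge detection components are omitted because the card does not describe them. This is not the released model and cannot load its weights.

```python
import torch
import torch.nn as nn


class TopicDriftSketch(nn.Module):
    """Hypothetical layout using the card's hyperparameters; NOT the released model."""

    def __init__(self, embed_dim=1024, window=8, hidden_dim=512, num_heads=4,
                 num_layers=3, dropout=0.2):
        super().__init__()
        self.window = window
        self.embed_dim = embed_dim
        # Assumption: project each turn's bge-m3 embedding down to the hidden size
        self.proj = nn.Linear(embed_dim, hidden_dim)
        # Bidirectional LSTM; hidden_dim // 2 per direction so outputs are hidden_dim wide
        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, num_layers=num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        # Multi-head self-attention across the turns in the window
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout,
                                          batch_first=True)
        self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden_dim, 1))

    def forward(self, x):
        # x: [batch, window * embed_dim] -> [batch, window, embed_dim]
        turns = x.reshape(x.shape[0], self.window, self.embed_dim)
        h = self.proj(turns)       # [batch, window, hidden_dim]
        h, _ = self.lstm(h)        # [batch, window, hidden_dim]
        h, _ = self.attn(h, h, h)  # [batch, window, hidden_dim]
        pooled = h.mean(dim=1)     # [batch, hidden_dim]
        # Assumption: drift scores lie in [0, 1], so squash with a sigmoid
        return torch.sigmoid(self.head(pooled)).squeeze(-1)  # [batch]


sketch = TopicDriftSketch().eval()
with torch.no_grad():
    scores = sketch(torch.randn(2, 8 * 1024))
print(scores.shape)  # torch.Size([2])
```

Sizing the LSTM at `hidden_dim // 2` per direction keeps the concatenated bidirectional output at 512, so the attention layer (512 / 4 = 128 dimensions per head) can consume it directly.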

## Performance Metrics

```txt
=== Full Training Results ===
Best Validation RMSE: 0.0133
Best Validation R²: 0.8873

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0139
R²: 0.8766
```
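
RMSE and R² follow their standard definitions. The plain-Python sketch below computes both from target and predicted drift scores; the helper names are illustrative. As a quick consistency check, and assuming the reported loss is mean squared error (the card does not say), the test loss should be roughly the square of the test RMSE: 0.0139² ≈ 0.000193, which rounds to the reported 0.0002.

```python
import math


def rmse(y_true, y_pred):
    """Root-mean-squared error between two equal-length sequences."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot (undefined for constant targets)."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# If the loss is MSE (an assumption), it should be about RMSE squared.
print(round(0.0139 ** 2, 4))  # 0.0002
```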

## Training Curves

## Usage

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load base embedding model
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
base_model.eval()

# Load topic drift detector.
# weights_only=False is needed on PyTorch >= 2.6 to unpickle a full nn.Module;
# only use it for checkpoints you trust. The model's class must be importable.
model = torch.load('models/v20241225_162244/topic_drift_model.pt',
                   map_location='cpu', weights_only=False)
model.eval()

# Prepare conversation window (exactly 8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

with torch.no_grad():
    # Get embeddings; cap each turn at 512 tokens (see Limitations)
    inputs = tokenizer(conversation, padding=True, truncation=True,
                       max_length=512, return_tensors='pt')
    hidden = base_model(**inputs).last_hidden_state  # [8, seq_len, 1024]

    # Mean-pool over real tokens only so padding does not dilute shorter turns
    # (assumption: match however the training embeddings were pooled)
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)  # [8, seq_len, 1]
    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # [8, 1024]

    # Reshape for model input [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)

    # Get drift score
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")
# Higher scores indicate more topic drift
```
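
The example scores a single 8-turn window. For a longer conversation, one approach is to embed every turn once and then slide an 8-turn window across the embeddings, flattening each window in the same turn-by-turn order as `embeddings.view(1, -1)`. The helper `sliding_window_inputs` below is a hypothetical sketch; whether the detector accepts several windows in one batch is an assumption (training used a batch size of 32, which suggests it does).

```python
import torch

WINDOW = 8  # turns per window, as stated in the model card


def sliding_window_inputs(turn_embeddings: torch.Tensor, window: int = WINDOW) -> torch.Tensor:
    """Stack every run of `window` consecutive turn embeddings into one flat row.

    turn_embeddings: [num_turns, dim] per-turn embeddings (dim=1024 for bge-m3).
    Returns [num_turns - window + 1, window * dim]; each row is laid out turn by
    turn, matching `embeddings.view(1, -1)` for a single window.
    """
    num_turns = turn_embeddings.shape[0]
    if num_turns < window:
        raise ValueError(f"need at least {window} turns, got {num_turns}")
    windows = [turn_embeddings[i:i + window] for i in range(num_turns - window + 1)]
    return torch.stack(windows).reshape(len(windows), -1)


# 12 turns of placeholder 1024-d embeddings -> 5 overlapping windows
fake = torch.randn(12, 1024)
batch = sliding_window_inputs(fake)
print(batch.shape)  # torch.Size([5, 8192])
```

Each row of `batch` could then be scored with `model(batch)` under the batching assumption above, giving one drift score per window position.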

## Training Details

- Dataset: [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 10
- Total epochs: 70 (early stopped)
- Training framework: PyTorch
- Base embeddings: BAAI/bge-m3
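
An early stopping patience of 10 usually means training halts once the monitored metric has gone 10 consecutive epochs without improving. The class below is a minimal sketch of that rule, assuming validation RMSE is the monitored metric (lower is better); the card does not state which metric or improvement threshold was used, and the validation curve in the demo is simulated. Under this convention, a run that stops at epoch 70 last improved at epoch 60.

```python
class EarlyStopping:
    """Stop training when a monitored metric (lower is better) stops improving."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> bool:
        """Record one epoch's metric; return True if training should stop."""
        if metric < self.best - self.min_delta:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=10)
# Simulated validation RMSE: improves through epoch 60, then plateaus.
val_rmse = [0.1 / e for e in range(1, 61)] + [0.1 / 60] * 40
for epoch, rmse in enumerate(val_rmse, start=1):
    if stopper.step(rmse):
        print(f"Early stop at epoch {epoch}")  # Early stop at epoch 70
        break
```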

## Limitations

- Works best with English conversations
- Requires exactly 8 turns of conversation
- Each turn should be between 1 and 512 tokens long
- Relies on BAAI/bge-m3 embeddings as input
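
The input limits above can be checked before calling the model. The function `validate_window` below is a hypothetical sketch that enforces exactly 8 turns and 1 to 512 tokens per turn. It takes the token counter as a parameter so it can run without downloading a tokenizer; the whitespace counter in the demo is only a crude stand-in.

```python
from typing import Callable, Sequence

WINDOW_TURNS = 8
MAX_TOKENS_PER_TURN = 512


def validate_window(turns: Sequence[str], count_tokens: Callable[[str], int]) -> None:
    """Raise ValueError if `turns` violates the input limits listed in the model card."""
    if len(turns) != WINDOW_TURNS:
        raise ValueError(f"expected exactly {WINDOW_TURNS} turns, got {len(turns)}")
    for i, turn in enumerate(turns):
        n = count_tokens(turn)
        if not 1 <= n <= MAX_TOKENS_PER_TURN:
            raise ValueError(f"turn {i} has {n} tokens; expected 1 to {MAX_TOKENS_PER_TURN}")


def count_whitespace_tokens(text: str) -> int:
    """Crude stand-in token counter: number of whitespace-separated words."""
    return len(text.split())


window = ["How was your weekend?"] * 8
validate_window(window, count_whitespace_tokens)  # passes silently
try:
    validate_window(window[:7], count_whitespace_tokens)
except ValueError as err:
    print(err)  # expected exactly 8 turns, got 7
```

With the bge-m3 tokenizer loaded, a counter such as `lambda s: len(tokenizer(s, add_special_tokens=False)["input_ids"])` would be closer to what the model sees; the card does not say whether the 512-token limit includes special tokens.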
|