RETRAN / README.md
Leore42's picture
Update README.md
cd7d193 verified
---
license: apache-2.0
datasets:
- Leore42/RETAN
language:
- en
---
This is an extremely tiny model that summarizes text into 4 words.
You will need to download the config, tokenizer, and model files, then use this Python code as a starting point:
```python
import torch
import tkinter as tk
import json
import torch.nn as nn
import math
class ThemeExtractor(nn.Module):
    """Transformer encoder that scores every input token position.

    The forward pass emits one logit per token; positions with the
    highest scores are the candidate "theme" words.
    """

    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=1, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        # NOTE: submodule creation order is kept stable so that seeded
        # random initialization stays reproducible.
        self.embedding = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128, dropout=dropout)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return per-token logits of shape (batch, seq_len) for token ids ``x``."""
        scaled = self.embedding(x) * math.sqrt(self.d_model)
        # nn.TransformerEncoder expects (seq, batch, dim) by default,
        # hence the transpose in and back out.
        hidden = self.encoder(scaled.transpose(0, 1)).transpose(0, 1)
        return self.fc(self.dropout(hidden)).squeeze(-1)
def load_model_and_tokenizer():
    """Load the config, vocabulary, and model weights from the working directory.

    Expects three files next to the script:
      - ``config.json``     — must contain "vocab_size" (and "max_len" for inference)
      - ``tokenizer.json``  — token -> id mapping
      - ``theme_extractor.pth`` — state dict for ``ThemeExtractor``

    Returns:
        (model, vocab, config). The model is left on the CPU; move it
        with ``model.to(device)`` before running CUDA inference.
    """
    with open('config.json', 'r') as f:
        config = json.load(f)
    with open('tokenizer.json', 'r') as f:
        vocab = json.load(f)
    model = ThemeExtractor(config["vocab_size"], d_model=64, nhead=4, num_layers=1, dropout=0.2)
    # map_location="cpu" lets a checkpoint saved on a GPU machine load
    # on a CPU-only host instead of raising a deserialization error.
    model.load_state_dict(torch.load("theme_extractor.pth", map_location="cpu"))
    return model, vocab, config
def generate_text(model, vocab, config, input_text):
    """Return up to 4 words from ``input_text`` that the model scores as thematic.

    The text is lowercased and whitespace-tokenized, mapped through
    ``vocab`` (unknown words fall back to "<unk>"), padded/truncated to
    ``config["max_len"]``, and scored by ``model``; the words at the
    top-scoring positions are returned in their original order.

    Args:
        model: a ThemeExtractor-like module returning per-token logits.
        vocab: dict mapping token -> id; must contain "<unk>" and "<pad>".
        config: dict containing "max_len".
        input_text: raw text to summarize.

    Returns:
        A space-joined string of at most 4 selected words (fewer if the
        top positions fall on padding).
    """
    max_len = config["max_len"]
    tokens = input_text.lower().split()
    token_ids = [vocab.get(tok, vocab["<unk>"]) for tok in tokens][:max_len]
    token_ids += [vocab["<pad>"]] * (max_len - len(token_ids))
    # Put the input on whatever device the model's weights live on,
    # instead of relying on a module-level ``device`` global.
    model_device = next(model.parameters()).device
    input_tensor = torch.tensor([token_ids], dtype=torch.long, device=model_device)
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)
    probs = torch.sigmoid(logits).squeeze(0)
    # Guard: torch.topk raises if k exceeds the number of positions
    # (i.e. when max_len < 4).
    k = min(4, probs.numel())
    indices = sorted(torch.topk(probs, k).indices.tolist())
    # Drop indices that landed on padding positions.
    theme_words = [tokens[i] for i in indices if i < len(tokens)]
    return ' '.join(theme_words)
# Pick the compute device and load everything once at startup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, vocab, config = load_model_and_tokenizer()
# Bug fix: the model must live on the same device as the input tensors
# built in generate_text — without this, CUDA hosts crash with a
# CPU/GPU device-mismatch error at inference time.
model = model.to(device)
def on_generate():
    """Tk button callback: summarize the entry text and display the themes."""
    themes = generate_text(model, vocab, config, entry_input.get())
    label_output.config(text="Generated Themes: " + themes)
# Build the minimal tkinter UI: an entry, a button, and an output label.
root = tk.Tk()
root.title("Theme Extractor")

entry_input = tk.Entry(root, width=50)
button_generate = tk.Button(root, text="Generate Themes", command=on_generate)
label_output = tk.Label(root, text="Generated Themes: ", wraplength=400)

# pack() order determines vertical layout: entry, button, then label.
for widget in (entry_input, button_generate, label_output):
    widget.pack(pady=10)

root.mainloop()
```