File size: 969 Bytes
7d22c1d
 
 
0b0f289
 
 
7d22c1d
 
 
62a4b51
7d22c1d
62a4b51
7d22c1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from transformers import RobertaForMaskedLM, RobertaTokenizer
from fastapi import FastAPI, HTTPException

app = FastAPI()

# Load the pre-trained model and tokenizer.
model = RobertaForMaskedLM.from_pretrained('roberta-base')
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Load the dataset, in this case "cyberpunk_lore.txt".
with open("cyberpunk_lore.txt", "r", encoding="utf-8") as f:
    dataset = f.read()

# Run one fine-tuning step on the dataset.
# Fixes vs. the original:
#  - truncate to the model's maximum sequence length (512 for RoBERTa) so
#    the forward pass cannot crash on a dataset longer than the limit;
#  - actually apply the gradients: the original called loss.backward() but
#    never created an optimizer or took a step, so no weight was updated.
# NOTE(review): a single gradient step over one (truncated) chunk is not a
# real fine-tune — kept to one step to preserve the original's intent.
input_ids = torch.tensor([
    tokenizer.encode(
        dataset,
        add_special_tokens=True,
        max_length=tokenizer.model_max_length,
        truncation=True,
    )
])
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
model.train()
model.zero_grad()
outputs = model(input_ids, labels=input_ids)
loss, logits = outputs[:2]
loss.backward()
optimizer.step()          # apply the computed gradients
model.eval()              # switch to inference mode before serving

# Serve the model via FastAPI
@app.post("/predict")
def predict(prompt: str):
    input_ids = torch.tensor([tokenizer.encode(prompt, add_special_tokens=True)])
    outputs = model(input_ids)
    generated_text = tokenizer.decode(outputs[0].argmax(dim=1).tolist()[0])
    return {"generated_text": generated_text}