Spaces:
Runtime error
Runtime error
File size: 969 Bytes
7d22c1d 0b0f289 7d22c1d 62a4b51 7d22c1d 62a4b51 7d22c1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from fastapi import FastAPI, HTTPException
from transformers import DataCollatorForLanguageModeling
from transformers import RobertaForMaskedLM, RobertaTokenizer
app = FastAPI()

# Load the pre-trained model and tokenizer.
model = RobertaForMaskedLM.from_pretrained('roberta-base')
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Load the fine-tuning corpus ("cyberpunk_lore.txt"). An explicit encoding
# keeps the read from depending on the host's locale default.
with open("cyberpunk_lore.txt", "r", encoding="utf-8") as f:
    dataset = f.read()


def _chunk_token_ids(text, block_size):
    """Tokenize *text* and split it into id blocks of at most *block_size*.

    Each block is wrapped in the tokenizer's start/end special tokens, so it
    carries ``block_size - 2`` content tokens at most. Text that tokenizes to
    nothing yields an empty list.
    """
    ids = tokenizer.encode(text, add_special_tokens=False)
    step = block_size - 2  # reserve two slots for the special tokens
    return [
        tokenizer.build_inputs_with_special_tokens(ids[start:start + step])
        for start in range(0, len(ids), step)
    ]


def _fine_tune(text, epochs=1, lr=5e-5, mlm_probability=0.15):
    """Fine-tune the module-level ``model`` on *text* with a masked-LM loss.

    The text is cut into model-sized blocks (so no sequence exceeds the
    position-embedding limit). For every block, a random subset of tokens is
    masked, the loss on the masked positions is back-propagated, and an
    optimizer step is taken. The model is left in eval mode for serving.

    Returns:
        The loss of the final step as a float, or ``None`` if *text* produced
        no blocks.
    """
    blocks = _chunk_token_ids(text, tokenizer.model_max_length)
    # The collator masks tokens and builds ``labels`` that ignore unmasked
    # positions, so the model learns to predict hidden tokens rather than
    # copying its (unmasked) input verbatim.
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=mlm_probability
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    last_loss = None
    for _ in range(epochs):
        for block in blocks:
            batch = collator([{"input_ids": block}])
            optimizer.zero_grad()
            step_loss = model(**batch).loss
            step_loss.backward()
            # Without this step the gradients are computed but never applied.
            optimizer.step()
            last_loss = step_loss.item()
    # Serving follows immediately; turn dropout off.
    model.eval()
    return last_loss


# Train the model on the dataset before serving it.
loss = _fine_tune(dataset)
# Serve the model via FastAPI
@app.post("/predict")
def predict(prompt: str):
    """Run the masked-LM head over *prompt* and return its per-position guesses.

    ``prompt`` is a plain ``str`` parameter, so FastAPI reads it from the query
    string. For every input position, the highest-scoring vocabulary token is
    chosen. The resulting id sequence is decoded back to text with special
    tokens removed. RoBERTa is a masked language model, so this fills in or
    reconstructs tokens of the prompt rather than continuing it.

    Raises:
        HTTPException: 400 if the prompt is empty or whitespace only.
    """
    if not prompt.strip():
        raise HTTPException(status_code=400, detail="prompt must not be empty")
    # Truncate to the tokenizer's model_max_length so long prompts cannot
    # overflow the position embeddings.
    input_ids = torch.tensor(
        [tokenizer.encode(prompt, add_special_tokens=True, truncation=True)]
    )
    # Inference only: disable dropout and skip building the autograd graph.
    model.eval()
    with torch.no_grad():
        # Per the transformers docs: (batch, seq_len, vocab_size).
        logits = model(input_ids).logits
    # Pick the best token per position, i.e. argmax over the vocabulary (last)
    # axis, then take the single batch row.
    predicted_ids = logits.argmax(dim=-1)[0].tolist()
    generated_text = tokenizer.decode(predicted_ids, skip_special_tokens=True)
    return {"generated_text": generated_text}
|