Spaces:
Runtime error
Runtime error
File size: 1,344 Bytes
1fcd3cd 235cb56 1fcd3cd b83beb7 1fcd3cd 8aad2af 1fcd3cd c58e248 1fcd3cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import asyncio

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import LineByLineTextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments
# The ASGI application. CORS is fully open (any origin, method and header),
# which lets browser pages served from any origin call this API.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def load_model(model_path):
    """Return a ``GPT2LMHeadModel`` built by ``from_pretrained(model_path)``.

    Args:
        model_path: Anything ``from_pretrained`` accepts, e.g. a local
            checkpoint directory.
    """
    return GPT2LMHeadModel.from_pretrained(model_path)
def load_tokenizer(checkpoint):
    """Return a ``GPT2Tokenizer`` built by ``from_pretrained(checkpoint)``.

    Args:
        checkpoint: Anything ``from_pretrained`` accepts, e.g. a local
            tokenizer directory.
    """
    return GPT2Tokenizer.from_pretrained(checkpoint)
# Load the model and tokenizer once, at import time; every request served by
# this module then reuses these same two objects.
model_path = './checkpoint/'
model = load_model(model_path)
tokenizer = load_tokenizer('./tokenizer/')
def generate_text(sequence, max_new_tokens):
    """Sample a continuation of *sequence* from the module-level GPT-2 model.

    Args:
        sequence: Prompt text; non-string values are converted with ``str()``.
        max_new_tokens: Maximum number of tokens to sample after the prompt.

    Returns:
        The decoded prompt followed by the sampled continuation, with special
        tokens removed. If the prompt plus ``max_new_tokens`` would exceed the
        model's context window (``model.config.n_positions``), only the
        trailing prompt tokens are kept so that everything fits.
    """
    encoded = tokenizer(str(sequence), return_tensors='pt')
    ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    # GPT-2 has a fixed number of position embeddings; an over-long prompt
    # would index past them and crash inside generate. Keep the tail of the
    # prompt, since the continuation is generated from its end.
    context_size = getattr(model.config, 'n_positions', None)
    if context_size is not None and ids.size(1) + max_new_tokens > context_size:
        keep = max(context_size - max_new_tokens, 1)
        ids = ids[:, -keep:]
        attention_mask = attention_mask[:, -keep:]

    max_length = ids.size(1) + max_new_tokens
    final_outputs = model.generate(
        ids,
        # Passed explicitly: it cannot be inferred from the input because the
        # pad token is set to the EOS token below.
        attention_mask=attention_mask,
        do_sample=True,
        max_length=max_length,
        pad_token_id=model.config.eos_token_id,
    )
    return tokenizer.decode(final_outputs[0], skip_special_tokens=True)
def _extract_subject(generated, prefix):
    """Return the text the model produced after *prefix* in *generated*.

    The result is cut at the first further ``'Subject : '`` in that text, so
    only one subject line is returned. Returns ``''`` if no completion can be
    located.
    """
    marker = 'Subject : '
    if generated.startswith(prefix):
        # Normal case: decoding reproduced the prompt verbatim, so skip it
        # whole. A 'Subject : ' inside the user's own text is therefore not
        # mistaken for the start of the completion.
        completion = generated[len(prefix):]
    else:
        # Decoding did not reproduce the prompt exactly (e.g. tokenizer
        # clean-up of spaces); fall back to the text after the first marker.
        _, _, completion = generated.partition(marker)
    return completion.split(marker, 1)[0]


@app.get("/subject/{prompt}")
async def root(prompt: str):
    """Generate an email subject line for the email text in *prompt*.

    Builds the prompt ``"Email : <prompt> Subject : "``, samples 7 new tokens
    and returns ``{"subject": <text the model produced after the prompt>}``.
    """
    print(prompt)
    full_prompt = "Email : " + prompt + " Subject : "
    loop = asyncio.get_running_loop()
    # generate_text runs model.generate, which is blocking compute; run it
    # in the default thread pool so it does not stall the event loop (and
    # with it every other request) while the model samples.
    generated = await loop.run_in_executor(None, generate_text, full_prompt, 7)
    return {"subject": _extract_subject(generated, full_prompt)}
|