# aiml_gr16 / main.py
# (header reconstructed from Hugging Face Space page residue: "SurajSingh's
#  picture / Update main.py / 4bf0241 verified / raw history blame / 1.4 kB")
from fastapi import FastAPI
import os
os.environ["TRANSFORMERS_CACHE"] = "./checkpoint"
from transformers import LineByLineTextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
def load_model(model_path):
    """Load a GPT-2 language-model head from *model_path* (local dir or hub id)."""
    return GPT2LMHeadModel.from_pretrained(model_path)
def load_tokenizer(checkpoint):
    """Return the GPT-2 tokenizer for *checkpoint* (e.g. the hub id 'gpt2')."""
    return GPT2Tokenizer.from_pretrained(checkpoint)
model_path = r'./checkpoint/'
model = load_model(model_path)
tokenizer = load_tokenizer('gpt2')
def generate_text(sequence, max_new_tokens):
    """Sample a continuation of *sequence* with the module-level GPT-2 model.

    Args:
        sequence: prompt text to continue.
        max_new_tokens: maximum number of tokens to generate beyond the prompt.

    Returns:
        The full decoded text (prompt followed by the continuation), with
        special tokens stripped.
    """
    # `sequence` is already a str — no need for the f-string wrapper.
    input_ids = tokenizer.encode(sequence, return_tensors='pt')
    outputs = model.generate(
        input_ids,
        do_sample=True,  # sampling: output is non-deterministic by design
        # `max_new_tokens` bounds only the generated part, replacing the
        # manual `input_length + max_new_tokens` arithmetic for `max_length`.
        max_new_tokens=max_new_tokens,
        # GPT-2 defines no pad token; reuse EOS to silence the warning.
        pad_token_id=model.config.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
@app.get("/subject/{prompt}")
async def root(prompt: str):
    """Generate an email subject line for the given email body *prompt*.

    Returns:
        {"subject": <text after the first 'Subject : ' marker>}; an empty
        subject when the marker is absent from the generated text (the
        original indexed split(...)[1] unconditionally, which raised
        IndexError in that case).
    """
    print(prompt)
    generated = generate_text("Email : " + prompt + " Subject : ", 7)
    # Same slicing semantics as the original split(...)[1] on success
    # (segment between the first and second marker occurrences), but safe
    # when the marker is missing.
    parts = generated.split('Subject : ')
    subject = parts[1] if len(parts) > 1 else ""
    return {"subject": subject}