# Spaces: Sleeping
from fastapi import FastAPI | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
from transformers import BertTokenizer, EncoderDecoderModel | |
# One checkpoint id shared by the tokenizer and the model so the two cannot drift apart.
_CHECKPOINT = "cahya/bert2gpt-indonesian-summarization"

tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
# Reuse the tokenizer's cls/sep tokens as the begin/end-of-sequence markers.
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

model = EncoderDecoderModel.from_pretrained(_CHECKPOINT)
class Generate(BaseModel):
    """Response payload carrying a single text field."""

    # Holds either the generated summary or a fixed notice (see ``generate``).
    text: str
class Prompt(BaseModel):
    """Request payload carrying the text to be summarized."""

    # Raw input text; passed unchanged to ``generate`` by ``inference``.
    text: str
def generate(prompt: str):
    """Summarize *prompt* with the module-level tokenizer and model.

    Args:
        prompt: Text to summarize.

    Returns:
        A ``Generate`` whose ``text`` is the decoded summary, or the fixed
        message ``'Prompt not provided'`` when *prompt* is empty or
        whitespace-only. Sampling is enabled (``do_sample=True``), so the
        summary may differ between calls for the same prompt.
    """
    # Guard clause: nothing to summarize (whitespace-only counts as empty).
    if not prompt or not prompt.strip():
        return Generate(text='Prompt not provided')

    # Truncate to the encoder's position-embedding limit so an over-long
    # prompt is cut off instead of failing inside the model.
    input_ids = tokenizer.encode(
        prompt,
        return_tensors='pt',
        truncation=True,
        max_length=model.config.encoder.max_position_embeddings,
    )

    summary_ids = model.generate(
        input_ids,
        min_length=20,
        max_length=80,
        num_beams=10,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=2,
        use_cache=True,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
    )

    # Only the first returned sequence is decoded.
    summary_text = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return Generate(text=summary_text)
app = FastAPI()

# Open CORS policy: any origin, method and header may call the API.
# Credentials stay disabled: FastAPI's CORS documentation states that
# wildcard origins/methods/headers cannot be combined with
# allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Registered on the root path; without a route decorator the handler was
# defined but never reachable through ``app``.
@app.get("/")
def home():
    """Return the service name and version as a JSON-serializable dict."""
    return {'app': 'Summarization', 'version': 0.1}
# NOTE(review): no route decorator was visible for this handler, so it was
# never registered on ``app``. The path below is an assumption; confirm it
# against whatever clients call this service.
@app.post("/api/generate", response_model=Generate)
def inference(prompt: Prompt):
    """Summarize the text carried by the request body.

    Args:
        prompt: Request payload; its ``text`` field is passed to ``generate``.

    Returns:
        Whatever ``generate`` returns for ``prompt.text``.
    """
    return generate(prompt.text)