from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import BertTokenizer, EncoderDecoderModel

# Load the pretrained Indonesian BERT2GPT summarization tokenizer and model
tokenizer = BertTokenizer.from_pretrained("cahya/bert2gpt-indonesian-summarization")
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token
model = EncoderDecoderModel.from_pretrained("cahya/bert2gpt-indonesian-summarization")


class Generate(BaseModel):
    text: str


class Prompt(BaseModel):
    text: str


def generate(prompt: str):
    if prompt == '':
        return Generate(text='Prompt not provided')
    else:
        # Encode the prompt and generate a summary
        input_ids = tokenizer.encode(prompt, return_tensors='pt')
        summary_ids = model.generate(
            input_ids,
            min_length=20,
            max_length=80,
            num_beams=10,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=2,
            use_cache=True,
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.95,
        )
        summary_text = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return Generate(text=summary_text)


app = FastAPI()

# Allow cross-origin requests from any origin
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get('/')
def home():
    return {'app': 'Summarization', 'version': 0.1}


@app.post('/generate', response_model=Generate)
def inference(prompt: Prompt):
    return generate(prompt.text)
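
# --- Usage sketch (assumptions labeled below, not part of the original app) ---
# Assuming this module is saved as main.py and served locally with uvicorn on
# the default port 8000 (both the module name and the address are assumptions),
# the API could be started and exercised roughly like this:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"text": "Teks berita berbahasa Indonesia yang ingin diringkas ..."},
#   )
#   print(resp.json()["text"])  # the generated summary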