File size: 1,702 Bytes
944c27e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import BertTokenizer, EncoderDecoderModel

# Hugging Face Hub checkpoint that supplies both the tokenizer and the
# encoder-decoder summarization model loaded below.
MODEL_NAME = "cahya/bert2gpt-indonesian-summarization"

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# Give the tokenizer BOS/EOS tokens by reusing BERT's [CLS] and [SEP]
# special tokens for those roles.
tokenizer.bos_token, tokenizer.eos_token = tokenizer.cls_token, tokenizer.sep_token
model = EncoderDecoderModel.from_pretrained(MODEL_NAME)


# Request body accepted by POST /generate: the text to summarize.
class Prompt(BaseModel):
    text: str


# Response body returned by POST /generate: the generated summary, or an
# explanatory message when no prompt was supplied.
class Generate(BaseModel):
    text: str

def generate(prompt: str) -> Generate:
    """Summarize *prompt* with the module-level encoder-decoder model.

    Args:
        prompt: Text to summarize.

    Returns:
        A ``Generate`` whose ``text`` is the decoded summary, or the message
        ``'Prompt not provided'`` when *prompt* is empty or whitespace-only.
    """
    # Whitespace-only input would tokenize to nothing but special tokens, so
    # treat it the same as an empty prompt instead of running the model.
    if not prompt or not prompt.strip():
        return Generate(text='Prompt not provided')

    # The BERT encoder has a fixed-size position-embedding table; inputs longer
    # than that raise inside the model, so truncate to what the encoder accepts.
    max_input_tokens = model.config.encoder.max_position_embeddings
    input_ids = tokenizer.encode(
        prompt,
        return_tensors='pt',
        truncation=True,
        max_length=max_input_tokens,
    )

    # do_sample=True (with temperature/top_k/top_p) means the summary can
    # differ between calls for the same prompt.
    summary_ids = model.generate(
        input_ids,
        min_length=20,
        max_length=80,
        num_beams=10,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=2,
        use_cache=True,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
    )

    summary_text = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return Generate(text=summary_text)

app = FastAPI()

# Fully open CORS policy: any origin, method and header is accepted.
# NOTE(review): the FastAPI docs state credentials require explicit origins,
# not ["*"]; Starlette's CORSMiddleware copes by reflecting the caller's
# Origin, which effectively permits credentialed requests from any site.
# No auth/cookie handling is visible in this file, so allow_credentials=False
# is likely safe — confirm no browser client depends on it before changing.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# GET /: returns static service metadata (application name and version).
@app.get('/')
def home():
    info = {'app': 'Summarization'}
    info['version'] = 0.1
    return info

# POST /generate: takes a JSON body matching `Prompt` and returns the result
# of generate() as a `Generate` payload. Declared with plain `def`, so FastAPI
# runs the blocking model call in its threadpool instead of the event loop.
@app.post('/generate', response_model=Generate)
def inference(prompt: Prompt):
    result = generate(prompt.text)
    return result