VishnuPottabatthini commited on
Commit
a76270d
·
verified ·
1 Parent(s): 13bb3d4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import BartTokenizer, BartForConditionalGeneration
4
+ import torch
5
+
6
# Create the FastAPI application instance.
app = FastAPI()

# Directory holding the fine-tuned BART checkpoint and tokenizer files.
# NOTE(review): relative path containing spaces — assumes the process is
# started from a working directory that contains './BART model small/model';
# confirm at deploy time.
MODEL_DIR = './BART model small/model'

# Prefer a CUDA GPU when one is available; otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model once at startup, and move the model onto
# the selected device so every request reuses the same weights.
tokenizer = BartTokenizer.from_pretrained(MODEL_DIR)
model = BartForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)
16
# Request schema for the /summarize endpoint.
class Article(BaseModel):
    # Raw article text to be summarized.
    text: str
20
# API endpoint for summarization
@app.post("/summarize")
async def summarize(article: Article):
    """Summarize the posted article text with the fine-tuned BART model.

    Request body: ``{"text": "<article text>"}`` (see ``Article``).
    Returns: ``{"summary": "<generated summary>"}``.
    Raises: HTTP 400 for empty input, HTTP 500 for unexpected model errors.
    """
    # Reject empty/whitespace-only input up front with a client error
    # instead of running the model and surfacing a misleading 500.
    if not article.text.strip():
        raise HTTPException(status_code=400, detail="Field 'text' must be non-empty.")

    try:
        # Tokenize the input article, truncating to BART's 1024-token
        # encoder limit, and move the tensors to the model's device.
        inputs = tokenizer(
            article.text,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
        ).to(device)

        # inference_mode() disables autograd bookkeeping — lower memory
        # use and faster generation for a pure-inference endpoint.
        with torch.inference_mode():
            summary_ids = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=150,   # upper bound on summary length (tokens)
                min_length=30,    # lower bound on summary length (tokens)
                num_beams=4,      # beam search for higher-quality output
                early_stopping=True,
            )

        # Decode the best beam, dropping <s>/</s> special tokens.
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Return the summary in the response
        return {"summary": summary}

    except Exception as e:
        # Unexpected failure (tokenizer/model/device error): surface as 500,
        # chaining the original exception for server-side tracebacks.
        # NOTE(review): str(e) may leak internal details to clients —
        # consider logging e and returning a generic message in production.
        raise HTTPException(status_code=500, detail=str(e)) from e