VishnuPottabatthini committed on
Commit
0006ae1
·
verified ·
1 Parent(s): a76270d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -1,11 +1,7 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
  from transformers import BartTokenizer, BartForConditionalGeneration
4
  import torch
5
 
6
- # Initialize the FastAPI app
7
- app = FastAPI()
8
-
9
  # Load the fine-tuned BART model and tokenizer from the local directory
10
  MODEL_DIR = './BART model small/model'
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -13,17 +9,12 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  tokenizer = BartTokenizer.from_pretrained(MODEL_DIR)
14
  model = BartForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)
15
 
16
- # Define a request model for the API input
17
- class Article(BaseModel):
18
- text: str
19
-
20
- # API Endpoint for summarization
21
- @app.post("/summarize")
22
- async def summarize(article: Article):
23
  try:
24
  # Tokenize the input article
25
  inputs = tokenizer(
26
- article.text,
27
  return_tensors="pt",
28
  max_length=1024,
29
  truncation=True
@@ -42,8 +33,20 @@ async def summarize(article: Article):
42
  # Decode the summary
43
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
44
 
45
- # Return the summary in the response
46
- return {"summary": summary}
47
 
48
  except Exception as e:
49
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
2
  from transformers import BartTokenizer, BartForConditionalGeneration
3
  import torch
4
 
 
 
 
5
  # Load the fine-tuned BART model and tokenizer from the local directory
6
  MODEL_DIR = './BART model small/model'
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
9
  tokenizer = BartTokenizer.from_pretrained(MODEL_DIR)
10
  model = BartForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)
11
 
12
+ # Define the summarization function
13
+ def summarize(text):
 
 
 
 
 
14
  try:
15
  # Tokenize the input article
16
  inputs = tokenizer(
17
+ text,
18
  return_tensors="pt",
19
  max_length=1024,
20
  truncation=True
 
33
  # Decode the summary
34
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
35
 
36
+ return summary
 
37
 
38
  except Exception as e:
39
+ return str(e)
40
+
# Build the Gradio UI: a single text box in, a single text box out,
# wired to the summarize() function defined above.
interface = gr.Interface(
    fn=summarize,
    inputs="text",   # free-form article text typed/pasted by the user
    outputs="text",  # generated summary shown back to the user
    title="BART Summarization",
    description="Enter an article to generate a summary using a fine-tuned BART model.",
)

# Start the web app.
interface.launch()