sashdev commited on
Commit
97c782c
·
verified ·
1 Parent(s): 5d50a44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -31
app.py CHANGED
@@ -1,35 +1,20 @@
1
- import gradio as gr
2
  import torch
3
- from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
 
5
- # Load DistilBERT model and tokenizer
6
- model_name = "bhadresh-savani/distilbert-base-uncased-finetuned-sentiment"
7
- tokenizer = DistilBertTokenizer.from_pretrained(model_name)
8
- model = DistilBertForSequenceClassification.from_pretrained(model_name)
9
 
10
- # Use GPU if available
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model.to(device)
 
 
 
 
 
13
 
14
- # Define the prediction function
15
- def predict_sentiment(text):
16
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
17
- outputs = model(**inputs)
18
- predictions = torch.argmax(outputs.logits, dim=-1)
19
- return predictions.item()
20
-
21
- # Gradio interface
22
- with gr.Blocks() as sentiment_app:
23
- gr.Markdown("<h1>Sentiment Analysis with DistilBERT</h1>")
24
-
25
- input_box = gr.Textbox(label="Input Text", placeholder="Enter text to analyze sentiment")
26
- output_box = gr.Textbox(label="Sentiment Result", placeholder="Sentiment result will appear here")
27
-
28
- submit_button = gr.Button("Analyze Sentiment")
29
-
30
- # Button click event
31
- submit_button.click(fn=predict_sentiment, inputs=input_box, outputs=output_box)
32
-
33
- # Launch the app
34
- if __name__ == "__main__":
35
- sentiment_app.launch()
 
1
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
2
  import torch
 
3
 
4
+ # Load the T5 tokenizer and model
5
+ model_name = "t5-small" # You can use any T5 model available
6
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
7
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
8
 
9
+ # Example function to use the model
10
+ def summarize(text):
11
+ # Tokenize the input text
12
+ inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
13
+ # Generate summary
14
+ outputs = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
15
+ summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
16
+ return summary
17
 
18
+ # Example usage
19
+ text_to_summarize = "Your input text goes here."
20
+ print(summarize(text_to_summarize))