sashdev committed
Commit 5d50a44 · verified · 1 Parent(s): b91fdea

Update app.py

Files changed (1)
  1. app.py +18 -28
app.py CHANGED
@@ -1,45 +1,35 @@
  import gradio as gr
  import torch
- from transformers import T5Tokenizer, T5ForConditionalGeneration
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

- # Load T5 model and tokenizer
- model_name = "t5-base"  # Use a smaller model for faster inference
- tokenizer = T5Tokenizer.from_pretrained(model_name)
- model = T5ForConditionalGeneration.from_pretrained(model_name)
+ # Load DistilBERT model and tokenizer
+ model_name = "bhadresh-savani/distilbert-base-uncased-finetuned-sentiment"
+ tokenizer = DistilBertTokenizer.from_pretrained(model_name)
+ model = DistilBertForSequenceClassification.from_pretrained(model_name)

  # Use GPU if available
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)

- # Grammar correction function
- def correct_grammar(text):
-     input_text = f"correct: {text}"
-     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
-
-     # Generate corrected text
-     output_ids = model.generate(input_ids, max_length=512, num_beams=5, early_stopping=True)
-     corrected_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-     return corrected_text
-
- # Gradio interface function
- def correct_grammar_interface(text):
-     corrected_text = correct_grammar(text)
-     return corrected_text
+ # Define the prediction function
+ def predict_sentiment(text):
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
+     outputs = model(**inputs)
+     predictions = torch.argmax(outputs.logits, dim=-1)
+     return predictions.item()

  # Gradio interface
- with gr.Blocks() as grammar_app:
-     gr.Markdown("<h1>Fast Grammar Correction with T5</h1>")
+ with gr.Blocks() as sentiment_app:
+     gr.Markdown("<h1>Sentiment Analysis with DistilBERT</h1>")

-     with gr.Row():
-         input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
-         output_box = gr.Textbox(label="Corrected Text", placeholder="Corrected text will appear here", lines=4)
+     input_box = gr.Textbox(label="Input Text", placeholder="Enter text to analyze sentiment")
+     output_box = gr.Textbox(label="Sentiment Result", placeholder="Sentiment result will appear here")

-     submit_button = gr.Button("Correct Grammar")
+     submit_button = gr.Button("Analyze Sentiment")

      # Button click event
-     submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
+     submit_button.click(fn=predict_sentiment, inputs=input_box, outputs=output_box)

  # Launch the app
  if __name__ == "__main__":
-     grammar_app.launch()
+     sentiment_app.launch()
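
Note on the new behavior: predict_sentiment returns the raw argmax index (an int), so the "Sentiment Result" Textbox will display 0 or 1 rather than a label name. Below is a minimal sketch, not part of the committed code, of how that index could be resolved to a readable label before display; it assumes the checkpoint populates model.config.id2label and wraps the forward pass in torch.no_grad(). The helper name predict_sentiment_label is illustrative only.

# Sketch only: resolve the predicted class index to a label string.
def predict_sentiment_label(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    idx = torch.argmax(outputs.logits, dim=-1).item()
    # id2label comes from the checkpoint config; the label names depend on the model
    return model.config.id2label.get(idx, str(idx))

Wiring submit_button.click(fn=predict_sentiment_label, inputs=input_box, outputs=output_box) would then show the label string in the output box; the rest of the Blocks layout is unchanged.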