willco-afk commited on
Commit
8a77d11
·
verified ·
1 Parent(s): 80b97e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -2,21 +2,26 @@ import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
4
 
5
- # Load pre-trained model and tokenizer
6
- model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
7
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
8
 
9
- # Define a function to make predictions using the model
10
- def predict(text):
11
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
12
- with torch.no_grad():
13
- outputs = model(**inputs)
14
- logits = outputs.logits
15
- predicted_class = torch.argmax(logits, dim=-1).item()
16
- return predicted_class
17
 
18
- # Create Gradio interface
19
- iface = gr.Interface(fn=predict, inputs="text", outputs="text", live=True)
 
 
 
 
 
20
 
21
- # Launch the Gradio app
22
- iface.launch()
 
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
4
 
5
+ # Load the model and tokenizer
6
+ model_name = "willco-afk/my-model-name"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
+ def classify_text(text):
11
+ # Preprocess text
12
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
13
+ outputs = model(**inputs)
14
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
15
+ label = torch.argmax(probs, dim=1).item()
16
+ labels = ["english", "spanish", "tagalog"]
17
+ return labels[label]
18
 
19
+ # Define the Gradio interface
20
+ with gr.Blocks() as demo:
21
+ gr.Markdown("# Slang Translation Classifier")
22
+ input_text = gr.Textbox(label="Enter slang text", lines=1)
23
+ output_label = gr.Textbox(label="Predicted Language", interactive=False)
24
+ submit_button = gr.Button("Classify")
25
+ submit_button.click(classify_text, inputs=[input_text], outputs=[output_label])
26
 
27
+ demo.launch()