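"""Gradio demo: binary text classification with a fine-tuned RoBERTa model.

Reports the probability that the input text belongs to class 1, along with
live CPU/RAM usage and the caller's position in a small request queue.
"""
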
import gradio as gr
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch
import psutil
from queue import Queue
# Load the model and tokenizer from the specified directory
model_path = './finetuned_roberta'
tokenizer = RobertaTokenizer.from_pretrained(model_path)
model = RobertaForSequenceClassification.from_pretrained(model_path)
# Initialize a request queue with a maximum of 2 concurrent requests
request_queue = Queue(maxsize=2)
# Function to get CPU and RAM usage
def get_system_usage():
    cpu_usage = psutil.cpu_percent(interval=1)
    ram_usage = psutil.virtual_memory().percent
    return f"CPU Usage: {cpu_usage}%", f"RAM Usage: {ram_usage}%"

# Function to get the user's position in the queue
def get_queue_position():
return f"Queue Position: {request_queue.qsize() + 1}"
# Define the prediction function
def classify_text(text):
    # put() blocks while the queue is full, so at most 2 requests are processed at once
    request_queue.put(1)
    position_in_queue = get_queue_position()
    try:
        # Tokenize the input text
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
        # Get the model's prediction
        with torch.no_grad():
            outputs = model(**inputs)
        # Apply softmax to get probabilities
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # Get the probability of the class '1'
        prob_1 = probabilities[0][1].item()
        return {"Probability of being 1": prob_1, "Queue Position": position_in_queue}
    finally:
        request_queue.get()  # Remove request from queue

# Create the Gradio interface
with gr.Blocks() as iface:
    with gr.Row():
        gr.Markdown("### Text Classification with RoBERTa")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=2, placeholder="Enter text here...")
            classify_btn = gr.Button("Classify")
        with gr.Column():
            cpu_output = gr.Markdown("")
            ram_output = gr.Markdown("")
            queue_output = gr.Markdown("")
            output_json = gr.JSON()

    # Component updates made from a background thread never reach the UI, so poll with
    # `every=` instead (assumes a Gradio version where load(..., every=...) is supported;
    # it requires the queue to be enabled).
    iface.load(get_system_usage, inputs=None, outputs=[cpu_output, ram_output], every=5)

    classify_btn.click(classify_text, inputs=input_text, outputs=output_json)

# Launch the app
if __name__ == "__main__":
    iface.queue()  # enable the request queue (needed for the `every=` refresh on older Gradio versions)
    iface.launch(share=True)