import threading
from queue import Queue

import gradio as gr
import psutil
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification

# Load the fine-tuned model and tokenizer from the specified directory
model_path = './finetuned_roberta'
tokenizer = RobertaTokenizer.from_pretrained(model_path)
model = RobertaForSequenceClassification.from_pretrained(model_path)
model.eval()

# Bounded queue used as a crude semaphore: at most 2 requests run concurrently.
request_queue = Queue(maxsize=2)

# Report current CPU and RAM usage as display-ready strings
def get_system_usage():
    cpu_usage = psutil.cpu_percent(interval=1)
    ram_usage = psutil.virtual_memory().percent
    return f"CPU Usage: {cpu_usage}%", f"RAM Usage: {ram_usage}%"

# Approximate position of the current request in the queue
def get_queue_position():
    return f"Queue Position: {request_queue.qsize()}"

# Prediction function: classify one text and return the probability of class '1'
def classify_text(text):
    # put() blocks while the queue is full, limiting concurrency to 2 requests
    request_queue.put(1)
    position_in_queue = get_queue_position()
    try:
        # Tokenize the input text
        inputs = tokenizer(text, return_tensors='pt', padding=True,
                           truncation=True, max_length=128)

        # Run the model without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)

        # Apply softmax to turn logits into probabilities
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Probability assigned to class '1'
        prob_1 = probabilities[0][1].item()

        return {"Probability of being 1": prob_1, "Queue Position": position_in_queue}
    finally:
        request_queue.get()  # Release the slot so the next waiting request can run

# Create the Gradio interface
with gr.Blocks() as iface:
    with gr.Row():
        gr.Markdown("### Text Classification with RoBERTa")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=2, placeholder="Enter text here...")
            classify_btn = gr.Button("Classify")
        with gr.Column():
            cpu_output = gr.Markdown("")
            ram_output = gr.Markdown("")
            queue_output = gr.Markdown("")
            output_json = gr.JSON()

    # Refresh the CPU/RAM readouts every 2 seconds via a polling load event.
    # (Calling .update() on components from a background thread never reaches
    # the browser, so the original daemon-thread approach is replaced here.)
    iface.load(get_system_usage, inputs=None,
               outputs=[cpu_output, ram_output], every=2)

    classify_btn.click(classify_text, inputs=input_text, outputs=output_json)

# Launch the app (queue() is required for the periodic 'every' updates)
if __name__ == "__main__":
    iface.queue()
    iface.launch(share=True)