import os
import queue
import threading
import time

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the tokenizer and model from Hugging Face
model_name = "andreas122001/roberta-mixed-detector"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# API key read from the environment; if unset, no request can authenticate
API_KEY = os.getenv("API_KEY")

# Separate queues so authenticated requests are served first
authenticated_queue = queue.Queue()
non_authenticated_queue = queue.Queue()


def classify_text(text):
    # Tokenize the input text, truncating to the model's maximum length
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    # Perform inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    # Convert logits to probabilities and pick the most likely class
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_id = torch.argmax(predictions, dim=-1).item()
    confidence_score = predictions[0][predicted_id].item()
    # Map the class index to its human-readable name from the model config
    predicted_label = model.config.id2label[predicted_id]
    return predicted_label, confidence_score


def process_queue():
    while True:
        # Promote non-authenticated requests that have waited more than
        # 15 seconds so they are not starved by authenticated traffic
        while not non_authenticated_queue.empty():
            oldest = non_authenticated_queue.queue[0]
            if time.time() - oldest['timestamp'] > 15:
                authenticated_queue.put(non_authenticated_queue.get())
            else:
                break
        # Serve authenticated (and promoted) requests first
        if not authenticated_queue.empty():
            request = authenticated_queue.get()
        elif not non_authenticated_queue.empty():
            request = non_authenticated_queue.get()
        else:
            # Nothing to do; adjust the sleep time as necessary
            time.sleep(0.1)
            continue
        label, score = classify_text(request['text'])
        request['response'].append((label, score))


def queue_request(text, api_key):
    response = []
    request = {
        'text': text,
        'timestamp': time.time(),
        'response': response,
    }
    # Only a non-empty key matching API_KEY gets priority treatment
    if API_KEY and api_key == API_KEY:
        authenticated_queue.put(request)
    else:
        non_authenticated_queue.put(request)
    # Block until the worker thread fills in the response
    while not response:
        time.sleep(0.1)
    return response[0]


def get_queue_sizes():
    return authenticated_queue.qsize(), non_authenticated_queue.qsize()


# Start the background worker that drains both queues
threading.Thread(target=process_queue, daemon=True).start()

# Create the Gradio interface
interface = gr.Interface(
    fn=queue_request,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter text here..."),
        gr.Textbox(lines=1, placeholder="Enter API key here...", type="password"),
    ],
    outputs=[gr.Label(label="Predicted Label"), gr.Number(label="Confidence Score")],
    title="Text Classification with roberta-mixed-detector",
    description="A simple app to classify text using the roberta-mixed-detector model with priority queuing.",
)

queue_interface = gr.Interface(
    fn=get_queue_sizes,
    inputs=[],
    outputs=[
        gr.Number(label="Authenticated Queue Size"),
        gr.Number(label="Non-Authenticated Queue Size"),
    ],
    live=True,
)

# Combine the interfaces into one tabbed app
combined_interface = gr.TabbedInterface(
    [interface, queue_interface], ["Classify Text", "Queue Sizes"]
)

# Launch the app
combined_interface.launch()
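
# A minimal way to run this script (the filename app.py is assumed; the
# env var name API_KEY matches what os.getenv reads above):
#
#   API_KEY="my-secret-key" python app.py
#
# Gradio serves the tabbed app on http://127.0.0.1:7860 by default.
# Requests submitted with the matching key go straight to the priority
# queue; all others wait in the non-authenticated queue until the worker
# reaches them or promotes them after 15 seconds.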