oscarwang2 commited on
Commit
0a036bb
1 Parent(s): c385be2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -19
app.py CHANGED
@@ -1,37 +1,79 @@
1
  import gradio as gr
2
  from transformers import RobertaTokenizer, RobertaForSequenceClassification
3
  import torch
 
 
 
4
 
5
  # Load the model and tokenizer from the specified directory
6
  model_path = './finetuned_roberta'
7
  tokenizer = RobertaTokenizer.from_pretrained(model_path)
8
  model = RobertaForSequenceClassification.from_pretrained(model_path)
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # Define the prediction function
11
  def classify_text(text):
12
- # Tokenize the input text
13
- inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
14
 
15
- # Get the model's prediction
16
- with torch.no_grad():
17
- outputs = model(**inputs)
18
-
19
- # Apply softmax to get probabilities
20
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Get the probability of the class '1'
23
- prob_1 = probabilities[0][1].item()
 
 
 
 
 
 
24
 
25
- return {"Probability of being 1": prob_1}
26
 
27
- # Create the Gradio interface
28
- iface = gr.Interface(
29
- fn=classify_text,
30
- inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
31
- outputs="json",
32
- title="Text Classification with RoBERTa",
33
- description="Enter some text and get the probability of the text being classified as class 1.",
34
- )
 
35
 
36
  # Launch the app
37
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from transformers import RobertaTokenizer, RobertaForSequenceClassification
3
  import torch
4
+ import psutil
5
+ import threading
6
+ from queue import Queue
7
 
8
  # Load the model and tokenizer from the specified directory
9
  model_path = './finetuned_roberta'
10
  tokenizer = RobertaTokenizer.from_pretrained(model_path)
11
  model = RobertaForSequenceClassification.from_pretrained(model_path)
12
 
13
+ # Initialize a request queue with a maximum of 2 concurrent requests
14
+ request_queue = Queue(maxsize=2)
15
+
16
+ # Function to get CPU and RAM usage
17
+ def get_system_usage():
18
+ cpu_usage = psutil.cpu_percent(interval=1)
19
+ ram_usage = psutil.virtual_memory().percent
20
+ return f"CPU Usage: {cpu_usage}%", f"RAM Usage: {ram_usage}%"
21
+
22
+ # Function to get the user's position in the queue
23
+ def get_queue_position():
24
+ return f"Queue Position: {request_queue.qsize() + 1}"
25
+
26
  # Define the prediction function
27
  def classify_text(text):
28
+ request_queue.put(1) # Add request to queue
29
+ position_in_queue = get_queue_position()
30
 
31
+ while request_queue.full():
32
+ position_in_queue = get_queue_position()
33
+
34
+ try:
35
+ # Tokenize the input text
36
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
37
+
38
+ # Get the model's prediction
39
+ with torch.no_grad():
40
+ outputs = model(**inputs)
41
+
42
+ # Apply softmax to get probabilities
43
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
44
+
45
+ # Get the probability of the class '1'
46
+ prob_1 = probabilities[0][1].item()
47
+
48
+ return {"Probability of being 1": prob_1, "Queue Position": position_in_queue}
49
+ finally:
50
+ request_queue.get() # Remove request from queue
51
+
52
+ # Create the Gradio interface
53
+ with gr.Blocks() as iface:
54
+ with gr.Row():
55
+ gr.Markdown("### Text Classification with RoBERTa")
56
 
57
+ with gr.Row():
58
+ with gr.Column():
59
+ input_text = gr.Textbox(lines=2, placeholder="Enter text here...")
60
+ classify_btn = gr.Button("Classify")
61
+ with gr.Column():
62
+ cpu_output = gr.Markdown("")
63
+ ram_output = gr.Markdown("")
64
+ queue_output = gr.Markdown("")
65
 
66
+ output_json = gr.JSON()
67
 
68
+ def update_usage():
69
+ while True:
70
+ cpu_usage, ram_usage = get_system_usage()
71
+ cpu_output.update(cpu_usage)
72
+ ram_output.update(ram_usage)
73
+
74
+ threading.Thread(target=update_usage, daemon=True).start()
75
+
76
+ classify_btn.click(classify_text, inputs=input_text, outputs=output_json)
77
 
78
  # Launch the app
79
  if __name__ == "__main__":