SharmaAmit1818 commited on
Commit
1158018
·
verified ·
1 Parent(s): 4da0f71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -9,26 +9,29 @@ model = BertForSequenceClassification.from_pretrained('huawei-noah/TinyBERT_Gene
9
 
10
  # Function to process the CSV file and generate predictions
11
  def process_csv(file):
12
- # Read the CSV file
13
- df = pd.read_csv(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # Ensure the CSV has a 'text' column
16
- if 'text' not in df.columns:
17
- return "Error: The CSV file must contain a 'text' column."
18
-
19
- # Tokenize the input text
20
- inputs = tokenizer(df['text'].tolist(), return_tensors='pt', padding=True, truncation=True)
21
-
22
- # Perform inference
23
- with torch.no_grad():
24
- outputs = model(**inputs)
25
-
26
- # Get predicted classes
27
- _, predicted_classes = torch.max(outputs.logits, dim=1)
28
- df['predicted_class'] = predicted_classes.numpy()
29
-
30
- # Return the processed DataFrame as a CSV string
31
- return df.to_csv(index=False)
32
 
33
  # Create the Gradio interface
34
  input_csv = gr.File(label="Upload CSV File")
 
9
 
10
  # Function to process the CSV file and generate predictions
11
  def process_csv(file):
12
+ try:
13
+ # Read the CSV file
14
+ df = pd.read_csv(file.name) # Use file.name to get the file path
15
+ # Ensure the CSV has a 'text' column
16
+ if 'text' not in df.columns:
17
+ return "Error: The CSV file must contain a 'text' column."
18
+
19
+ # Tokenize the input text
20
+ inputs = tokenizer(df['text'].tolist(), return_tensors='pt', padding=True, truncation=True)
21
+
22
+ # Perform inference
23
+ with torch.no_grad():
24
+ outputs = model(**inputs)
25
+
26
+ # Get predicted classes
27
+ _, predicted_classes = torch.max(outputs.logits, dim=1)
28
+ df['predicted_class'] = predicted_classes.numpy()
29
+
30
+ # Return the processed DataFrame as a CSV string
31
+ return df.to_csv(index=False)
32
 
33
+ except Exception as e:
34
+ return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # Create the Gradio interface
37
  input_csv = gr.File(label="Upload CSV File")