Lalit1997 commited on
Commit
fbe6c18
·
verified ·
1 Parent(s): 779e42d

Update gen_ai.py

Browse files
Files changed (1) hide show
  1. gen_ai.py +3 -3
gen_ai.py CHANGED
@@ -19,14 +19,14 @@ label_dict = {
19
 
20
 
21
  class traditional_model:
22
- def __init__(self, query):
23
  self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
24
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
25
 
26
- def predict(self):
27
  self.model.to(device)
28
  self.model.eval()
29
- inputs = self.tokenizer(self.query, return_tensors="pt", truncation=True, padding=True).to(device) # Move input to device
30
  with torch.no_grad():
31
  outputs = self.model(**inputs)
32
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
 
19
 
20
 
21
  class traditional_model:
22
+ def __init__(self):
23
  self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
24
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
25
 
26
+ def predict(self,query):
27
  self.model.to(device)
28
  self.model.eval()
29
+ inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True).to(device) # Move input to device
30
  with torch.no_grad():
31
  outputs = self.model(**inputs)
32
  predicted_class = torch.argmax(outputs.logits, dim=1).item()