import os | |
from typing import Any | |
from openai import OpenAI | |
from rag_demo.rag.base.query import Query | |
from rag_demo.rag.base.template_factory import RAGStep | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
from loguru import logger | |
import torch | |
# Checkpoint used to classify incoming queries.
# Model trained on English greetings only.
model_name = "AdrienB134/greetings-classifier"

# Both objects are created once, when this module is imported, and are then
# shared by every QueryClassifier.generate() call below.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
class QueryClassifier(RAGStep):
    """RAG step that assigns a class label to an incoming query.

    The label is produced by the module-level sequence-classification
    ``model`` together with its ``tokenizer``.
    """

    def generate(self, query: Query) -> Any:
        """Return the predicted label for ``query``.

        Args:
            query: Query whose ``content`` is tokenized and classified.

        Returns:
            ``"Sources_needed"`` when the step runs in mock mode; otherwise
            the ``model.config.id2label`` entry for the highest-scoring
            class.
        """
        # Mock mode skips inference entirely and returns a fixed label.
        # NOTE(review): ``_mock`` is presumably set by RAGStep -- confirm.
        if self._mock:
            return "Sources_needed"

        # Pure inference: gradient tracking is not needed.
        with torch.no_grad():
            # truncation=True caps the token sequence at the tokenizer's
            # maximum length, so an overly long query is shortened instead
            # of failing inside the model's forward pass.
            inputs = tokenizer(
                query.content,
                return_tensors="pt",
                truncation=True,
            )
            logits = model(**inputs).logits

        # argmax() runs over the flattened logits, which equals the class
        # index only for a batch of one.
        # NOTE(review): assumes query.content is a single string -- confirm.
        predicted_id = logits.argmax().item()
        return model.config.id2label[predicted_id]