Adrien
feat: add inline sources and query classifier
cc3f1e1
raw
history blame contribute delete
874 Bytes
import os
from typing import Any
from openai import OpenAI
from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from loguru import logger
import torch
# Hub identifier of the sequence-classification checkpoint used by QueryClassifier.
# Caveat carried over from the original note: the model was trained on English
# greetings only, so predictions on other kinds of input may be unreliable.
model_name = "AdrienB134/greetings-classifier"

# Tokenizer and model are fetched once, at import time, from the same checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
class QueryClassifier(RAGStep):
    """RAG step that labels a query with the module-level sequence classifier.

    The returned label is a value from ``model.config.id2label``. In mock mode
    the step skips inference and always returns ``"Sources_needed"``.
    """

    def generate(self, query: Query) -> Any:
        """Classify ``query.content`` and return the predicted label name.

        Args:
            query: The user query; only its ``content`` text is classified.

        Returns:
            The label name of the highest-scoring class (looked up in
            ``model.config.id2label``), or ``"Sources_needed"`` when
            ``self._mock`` is truthy.
        """
        if self._mock:
            # Mock mode: bypass the model entirely and return a fixed label.
            return "Sources_needed"

        # truncation=True clips the input to the tokenizer's model_max_length,
        # so an unusually long query no longer overflows the model's position
        # embeddings and raises at inference time.
        inputs = tokenizer(query.content, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Reduce over the class axis (last dimension) rather than over the
        # flattened tensor; for this single input it yields one class index.
        predicted_id = logits.argmax(dim=-1).item()
        return model.config.id2label[predicted_id]