library_name: transformers
license: apache-2.0
Model Description
Query Rewriting in Retrieval-Augmented Large Language Models
arXiv: https://arxiv.org/abs/2305.14283
Large Language Models (LLMs) serve as powerful black-box readers in the retrieve-then-read pipeline and have made remarkable progress on knowledge-intensive tasks. This work approaches retrieval-augmented LLMs from the perspective of query rewriting. It introduces a new framework, Rewrite-Retrieve-Read, which replaces the previous retrieve-then-read pipeline. We first prompt an LLM to generate a search query, then use a web search engine to retrieve contexts. To better align the query with the frozen modules, we also propose a trainable scheme for the pipeline. A small language model is adopted as a trainable rewriter that caters to the black-box LLM reader. The rewriter is trained by reinforcement learning, using feedback from the LLM reader.
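The Rewrite-Retrieve-Read flow and the reader-feedback signal can be sketched as below. This is a minimal, self-contained illustration under stated assumptions, not the authors' implementation. The rewriter, retriever and reader are toy stand-ins (hypothetical `toy_rewriter`, `toy_retriever`, `toy_reader`) for the trainable T5 rewriter, the web search engine and the frozen LLM reader. The reward (exact match plus token-level F1 of the reader's answer) is also an assumption, and the paper's exact reward may differ.

```python
from collections import Counter
from typing import Callable, List, Tuple

# Type aliases for the three (hypothetical) pipeline stages.
Rewriter = Callable[[str], str]           # question -> search query
Retriever = Callable[[str], List[str]]    # search query -> contexts
Reader = Callable[[str, List[str]], str]  # question + contexts -> answer


def rewrite_retrieve_read(question: str, rewriter: Rewriter, retriever: Retriever,
                          reader: Reader) -> Tuple[str, List[str], str]:
    """Run one Rewrite-Retrieve-Read pass and return every intermediate output."""
    query = rewriter(question)           # 1. rewrite the input into a search query
    contexts = retriever(query)          # 2. retrieve contexts with the rewritten query
    answer = reader(question, contexts)  # 3. the frozen reader answers from the contexts
    return query, contexts, answer


def exact_match(prediction: str, gold: str) -> float:
    """1.0 if prediction and gold match after lowercasing and stripping, else 0.0."""
    return float(prediction.strip().lower() == gold.strip().lower())


def token_f1(prediction: str, gold: str) -> float:
    """Token-level F1 between lowercased, whitespace-split prediction and gold."""
    pred_tokens = prediction.lower().split()
    gold_tokens = gold.lower().split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def reader_feedback_reward(answer: str, gold: str) -> float:
    """Assumed reward for the rewriter: EM + F1 of the reader's answer."""
    return exact_match(answer, gold) + token_f1(answer, gold)


# Toy stand-ins (hypothetical) so the sketch runs without a model or search API.
DOCS = [
    "Nicholas Ray was a film director.",
    "Elia Kazan was a film director.",
    "Paris is the capital of France.",
]


def toy_rewriter(question: str) -> str:
    return "Nicholas Ray Elia Kazan profession"


def toy_retriever(query: str) -> List[str]:
    terms = query.split()
    return [doc for doc in DOCS if any(term in doc for term in terms)]


def toy_reader(question: str, contexts: List[str]) -> str:
    return "director" if any("director" in c for c in contexts) else "unknown"


question = "What profession does Nicholas Ray and Elia Kazan have in common?"
query, contexts, answer = rewrite_retrieve_read(question, toy_rewriter, toy_retriever, toy_reader)
reward = reader_feedback_reward(answer, gold="director")
print(query, contexts, answer, reward)
```

In training, a scalar like `reward` would be the reinforcement-learning signal used to update the rewriter, while the reader stays frozen.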
- Developed by: https://github.com/xbmxb/RAG-query-rewriting
- Model type: google/t5-large
- Checkpoint: checkpoint_20
Inference
from transformers import T5Tokenizer, T5ForConditionalGeneration, BitsAndBytesConfig

# 8-bit quantization (requires a CUDA GPU plus the bitsandbytes and accelerate packages)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = T5ForConditionalGeneration.from_pretrained(
    "catyung/t5l-turbo-hotpot-0331",
    quantization_config=quantization_config,
    device_map="auto",  # place the quantized weights on the available GPU
)
tokenizer = T5Tokenizer.from_pretrained("catyung/t5l-turbo-hotpot-0331")

# Inference: define the query first so the f-string below can use it
user_query = "What profession does Nicholas Ray and Elia Kazan have in common?"
rewrite_prompt = f"""rewrite a better search query: {user_query}
answer:"""

# Move the inputs to the same device as the model, then generate the rewrite
input_ids = tokenizer(rewrite_prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model.generate(input_ids, max_new_tokens=50)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)
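The rewriter is sensitive to its input format, so it can help to move the prompt template from the inference snippet into a helper and reuse it for batches. This sketch assumes that template string is the format the checkpoint expects. `build_rewrite_prompt` and `split_queries` are hypothetical helpers, not part of the transformers API. The `;` separator in the usage example is an assumption about the output format, not documented behavior.

```python
from typing import List

# Template taken from the inference snippet; assumed to match the format the
# checkpoint expects (note the newline before "answer:").
REWRITE_TEMPLATE = "rewrite a better search query: {query}\nanswer:"


def build_rewrite_prompt(user_query: str) -> str:
    """Hypothetical helper: format one query (whitespace stripped) into the rewrite prompt."""
    return REWRITE_TEMPLATE.format(query=user_query.strip())


def split_queries(generated: str, sep: str) -> List[str]:
    """Hypothetical helper: split a decoded rewrite into separate queries.

    The separator is an assumption; check what the checkpoint actually emits.
    """
    return [q.strip() for q in generated.split(sep) if q.strip()]


queries = [
    "What profession does Nicholas Ray and Elia Kazan have in common?",
    "Who directed Rebel Without a Cause?",
]
prompts = [build_rewrite_prompt(q) for q in queries]

# With `tokenizer` and `model` loaded as in the inference snippet (not run here),
# the prompts can be rewritten as one padded batch:
#   batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
#   outputs = model.generate(**batch, max_new_tokens=50)
#   rewrites = tokenizer.batch_decode(outputs, skip_special_tokens=True)

print(prompts[0])
```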