---
library_name: transformers
license: apache-2.0
---
### Model Description
Query Rewriting in Retrieval-Augmented Large Language Models

arXiv: https://arxiv.org/abs/2305.14283

Large Language Models (LLMs) serve as powerful, black-box readers in the retrieve-then-read pipeline and have made remarkable progress on knowledge-intensive tasks. This work introduces Rewrite-Retrieve-Read, a framework that replaces the conventional retrieve-then-read pipeline for retrieval-augmented LLMs from the perspective of query rewriting. We first prompt an LLM to rewrite the query, then use a web search engine to retrieve contexts. Furthermore, to better align the query to the frozen modules, we propose a trainable scheme for the pipeline: a small language model is adopted as a trainable rewriter to cater to the black-box LLM reader, and is trained by reinforcement learning using feedback from the LLM reader.
- **Developed by:** https://github.com/xbmxb/RAG-query-rewriting
- **Model type:** google/t5-large
- **Checkpoint:** checkpoint_20
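
The Rewrite-Retrieve-Read flow described above can be sketched as follows. This is a minimal illustration, not the paper's implementation: `rewrite_query`, `web_search`, and `read_with_llm` are hypothetical stand-ins for the trained T5 rewriter, a web search engine, and the frozen black-box LLM reader.

```python
# Sketch of the Rewrite-Retrieve-Read pipeline with stub components.
def rewrite_query(query: str) -> str:
    # Stand-in for the trained T5 rewriter (see the Inference section below).
    return f"search: {query}"

def web_search(query: str) -> list[str]:
    # Stand-in for a web search engine returning context passages.
    return [f"passage retrieved for '{query}'"]

def read_with_llm(query: str, contexts: list[str]) -> str:
    # Stand-in for the frozen black-box LLM reader.
    return f"answer derived from {len(contexts)} passage(s)"

def rewrite_retrieve_read(user_query: str) -> str:
    rewritten = rewrite_query(user_query)       # 1. rewrite the raw query
    contexts = web_search(rewritten)            # 2. retrieve with the rewrite
    return read_with_llm(user_query, contexts)  # 3. read: answer from contexts

print(rewrite_retrieve_read("Who directed Rebel Without a Cause?"))
```

Only the rewriter is trainable; the search engine and the reader stay frozen, which is why the rewriter is optimized with reinforcement learning from the reader's feedback rather than by backpropagating through the whole pipeline.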
### Inference
```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, BitsAndBytesConfig

# 8-bit quantization
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = T5ForConditionalGeneration.from_pretrained(
    'catyung/t5l-turbo-hotpot-0331',
    quantization_config=quantization_config,
)
tokenizer = T5Tokenizer.from_pretrained('catyung/t5l-turbo-hotpot-0331')

# Inference: define the query before building the prompt
user_query = "What profession does Nicholas Ray and Elia Kazan have in common?"

rewrite_prompt = f"""rewrite a better search query: {user_query}
answer:"""

input_ids = tokenizer(rewrite_prompt, return_tensors="pt").input_ids.to(device)
outputs = model.generate(input_ids, max_new_tokens=50)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)
```