---
library_name: transformers
license: apache-2.0
---

### Model Description

**Paper:** Query Rewriting in Retrieval-Augmented Large Language Models

**arXiv:** https://arxiv.org/abs/2305.14283

Large Language Models (LLMs) act as powerful, black-box readers in the retrieve-then-read pipeline and have made remarkable progress on knowledge-intensive tasks. This work introduces a new framework for retrieval-augmented LLMs, Rewrite-Retrieve-Read, which replaces the previous retrieve-then-read pipeline by adding a query rewriting step. An LLM is first prompted to generate the query, and a web search engine then retrieves contexts for it. To better align the query with the frozen modules, the work also proposes a trainable scheme for the pipeline: a small language model is adopted as a trainable rewriter that caters to the black-box LLM reader, and it is trained with reinforcement learning using the reader's feedback.
- **Developed by:** https://github.com/xbmxb/RAG-query-rewriting
- **Model type:** google/t5-large
- **Checkpoint:** checkpoint_20
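
The Rewrite-Retrieve-Read flow described above can be sketched roughly as follows. This is a minimal illustration, not the authors' implementation: `rewrite_retrieve_read`, the three callables, and `top_k` are hypothetical names, and the stub components stand in for the rewriter model, a web search engine, and the LLM reader. The sketch assumes the reader receives the original question together with the retrieved contexts.

```python
from typing import Callable, List

# Prompt format taken from the inference example on this card.
REWRITE_TEMPLATE = "rewrite a better search query: {query}\nanswer:"


def rewrite_retrieve_read(
    question: str,
    rewrite: Callable[[str], str],
    retrieve: Callable[[str], List[str]],
    read: Callable[[str, List[str]], str],
    top_k: int = 3,
) -> str:
    """Run the three stages in order (hypothetical helper, for illustration)."""
    # 1. Rewrite: turn the user question into a search query.
    search_query = rewrite(REWRITE_TEMPLATE.format(query=question))
    # 2. Retrieve: fetch contexts for the rewritten query, keep the first top_k.
    contexts = retrieve(search_query)[:top_k]
    # 3. Read: answer the original question using the retrieved contexts.
    return read(question, contexts)


# Stand-in components so the sketch runs without a model or network access.
def stub_rewrite(prompt: str) -> str:
    return "Nicholas Ray Elia Kazan profession"


def stub_retrieve(query: str) -> List[str]:
    return [f"snippet {i} for: {query}" for i in range(5)]


def stub_read(question: str, contexts: List[str]) -> str:
    return f"answer based on {len(contexts)} contexts"


answer = rewrite_retrieve_read(
    "What profession does Nicholas Ray and Elia Kazan have in common?",
    stub_rewrite,
    stub_retrieve,
    stub_read,
)
print(answer)
```

In practice, `rewrite` would wrap this checkpoint, while `retrieve` and `read` would wrap whichever search engine and LLM reader are available.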

### Inference

```python
from transformers import BitsAndBytesConfig, T5ForConditionalGeneration, T5Tokenizer

# 8-bit quantization (requires the bitsandbytes package and a CUDA GPU)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = T5ForConditionalGeneration.from_pretrained(
    "catyung/t5l-turbo-hotpot-0331",
    quantization_config=quantization_config,
)
tokenizer = T5Tokenizer.from_pretrained("catyung/t5l-turbo-hotpot-0331")

# Define the query first: the f-string below interpolates it immediately
user_query = "What profession does Nicholas Ray and Elia Kazan have in common?"

rewrite_prompt = f"""rewrite a better search query: {user_query}
answer:"""

# Inference: move the inputs to the device the model was loaded on
input_ids = tokenizer(rewrite_prompt, return_tensors="pt").input_ids.to(model.device)

outputs = model.generate(input_ids, max_new_tokens=50)

result = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(result)
```
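
To rewrite several queries at once, the same prompt format can be applied per query and the prompts tokenized as a padded batch. The sketch below is one possible way to do that; `build_rewrite_prompts` and `rewrite_batch` are hypothetical helper names, and `rewrite_batch` assumes it is given a `tokenizer` and `model` loaded as in the inference snippet.

```python
from typing import List


def build_rewrite_prompts(queries: List[str]) -> List[str]:
    """Wrap each query in the prompt format used by the inference snippet."""
    return [f"rewrite a better search query: {q}\nanswer:" for q in queries]


def rewrite_batch(queries: List[str], tokenizer, model, max_new_tokens: int = 50) -> List[str]:
    """Rewrite a list of queries in one generate() call (illustrative helper)."""
    prompts = build_rewrite_prompts(queries)
    # padding=True pads the prompts to a common length and returns an attention mask
    encoded = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    outputs = model.generate(**encoded, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
```

For example, `rewrite_batch(["first question?", "second question?"], tokenizer, model)` would return one rewritten query per input, in the same order.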