metadata
library_name: transformers
license: apache-2.0

Model Description

Query Rewriting for Retrieval-Augmented Large Language Models

arXiv: https://arxiv.org/abs/2305.14283

Large Language Models (LLMs) act as powerful, black-box readers in the retrieve-then-read pipeline and have made remarkable progress on knowledge-intensive tasks. This work introduces a new framework, Rewrite-Retrieve-Read, which replaces the previous retrieve-then-read pipeline for retrieval-augmented LLMs and approaches the problem from the perspective of query rewriting. We first prompt an LLM to generate the query, then use a web search engine to retrieve contexts. Furthermore, to better align the query with the frozen modules, we propose a trainable scheme for our pipeline: a small language model is adopted as a trainable rewriter that caters to the black-box LLM reader, and it is trained with reinforcement learning using the reader's feedback.
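The three stages described above can be sketched as a single function. This is only an illustrative outline, not the paper's implementation: `rewrite_retrieve_read`, the three callables, and the `top_k` parameter are hypothetical names, and the caller is assumed to supply the rewriter (for example this model), a web search wrapper, and an LLM reader.

```python
from typing import Callable, List

# Hypothetical component signatures; none of these come from the paper's code.
Rewriter = Callable[[str], str]            # question -> search query
Retriever = Callable[[str], List[str]]     # search query -> retrieved contexts
Reader = Callable[[str, List[str]], str]   # (question, contexts) -> answer


def rewrite_retrieve_read(
    question: str,
    rewriter: Rewriter,
    retriever: Retriever,
    reader: Reader,
    top_k: int = 3,
) -> str:
    """Run the rewrite, retrieve, and read stages in order."""
    # 1. Rewrite: turn the original question into a search query.
    query = rewriter(question)
    # 2. Retrieve: fetch contexts for the rewritten query, keeping the first top_k.
    contexts = retriever(query)[:top_k]
    # 3. Read: the frozen reader answers the original question using the contexts.
    return reader(question, contexts)
```

Passing the components in as callables keeps the retriever and reader frozen and swappable, which mirrors the setup described above where only the rewriter is trained.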

Inference

from transformers import T5Tokenizer, T5ForConditionalGeneration, BitsAndBytesConfig
import torch

# 8-bit quantization (requires the bitsandbytes package; typically a CUDA GPU)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = T5ForConditionalGeneration.from_pretrained('catyung/t5l-turbo-hotpot-0331',
                                quantization_config=quantization_config)

tokenizer = T5Tokenizer.from_pretrained('catyung/t5l-turbo-hotpot-0331')

# Inference
user_query = "What profession does Nicholas Ray and Elia Kazan have in common?"

# Build the prompt only after user_query is defined: the f-string is evaluated immediately
rewrite_prompt = f"""rewrite a better search query: {user_query}
answer:"""

# Move the inputs to the device the quantized model was loaded on
input_ids = tokenizer(rewrite_prompt, return_tensors="pt").input_ids.to(model.device)

outputs = model.generate(input_ids, max_new_tokens=50)

result = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(result)
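For repeated calls, the prompt construction and generation steps above could be wrapped in small helpers. This is a minimal sketch: `build_rewrite_prompt` and `rewrite_query` are hypothetical names that are not part of this repository, and `model` and `tokenizer` are assumed to be the objects loaded in the snippet above.

```python
def build_rewrite_prompt(user_query: str) -> str:
    """Format a question with the prompt template used in the snippet above."""
    return f"rewrite a better search query: {user_query}\nanswer:"


def rewrite_query(model, tokenizer, user_query: str, max_new_tokens: int = 50) -> str:
    """Generate a rewritten search query for a single question."""
    prompt = build_rewrite_prompt(user_query)
    # Tokenize and move the tensors to the device the model lives on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Passing the whole encoding forwards the attention mask along with input_ids.
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```

With the model and tokenizer already loaded, `print(rewrite_query(model, tokenizer, user_query))` would print the rewritten query for `user_query`.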