import gradio as gr
import pickle
import numpy as np
import glob
import tqdm
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from tevatron.retriever.searcher import FaissFlatSearcher
import logging
import os
import json
import spaces
import ir_datasets
import pytrec_eval
from huggingface_hub import login

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authenticate with HF_TOKEN. A missing env var raises KeyError on purpose:
# the gated Llama-2 base model cannot be downloaded without a token.
login(token=os.environ['HF_TOKEN'])

# Global variables
CUR_MODEL = "orionweller/repllama-instruct-hard-positives-v2-joint"  # LoRA adapter repo
BASE_MODEL = "meta-llama/Llama-2-7b-hf"                              # base model the adapter was trained on
tokenizer = None
model = None
retrievers = {}      # dataset name -> FaissFlatSearcher over its corpus embeddings
corpus_lookups = {}  # dataset name -> list mapping FAISS row index -> passage id
queries = {}         # dataset name -> list of query texts (encoding order)
q_lookups = {}       # dataset name -> {query_id: query text}; insertion order matches `queries`
qrels = {}           # dataset name -> {query_id: {doc_id: relevance}}
datasets = ["scifact", "arguana"]
current_dataset = "scifact"


def pool(last_hidden_states, attention_mask):
    """Last-token pooling: return the hidden state of the final non-pad token.

    Only correct with right padding — the tokenizer is configured with
    padding_side="right" in load_model().
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # Index of the last real (non-pad) token in each sequence
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden.shape[0]
    return last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]


def create_batch_dict(tokenizer, input_texts, max_length=512):
    """Tokenize a batch, append EOS to every sequence, then pad to a multiple of 8.

    Truncates to max_length - 1 so that the appended EOS token never pushes a
    sequence past max_length.
    """
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length - 1,
        return_token_type_ids=False,
        return_attention_mask=False,
        padding=False,
        truncation=True
    )
    batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
    return tokenizer.pad(
        batch_dict,
        padding=True,
        pad_to_multiple_of=8,
        return_attention_mask=True,
        return_tensors="pt",
    )


def load_model():
    """Load the base model, merge the LoRA adapter into it, and move it to GPU."""
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Llama-2 ships without a pad token; reuse EOS and pad on the right so
    # last-token pooling (see pool) indexes the true final token.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    base_model_instance = AutoModel.from_pretrained(BASE_MODEL)
    model = PeftModel.from_pretrained(base_model_instance, CUR_MODEL)
    # Fold adapter weights into the base model for faster inference
    model = model.merge_and_unload()
    model.eval()
    model.cuda()


def load_corpus_embeddings(dataset_name):
    """Load pre-computed corpus embedding shards into a flat FAISS index."""
    global retrievers, corpus_lookups
    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
    index_files = glob.glob(corpus_path)
    logger.info(f'Loading {len(index_files)} files into index for {dataset_name}.')

    # The first shard's embedding width fixes the index dimensionality
    p_reps_0, p_lookup_0 = pickle_load(index_files[0])
    retrievers[dataset_name] = FaissFlatSearcher(p_reps_0)

    shards = [(p_reps_0, p_lookup_0)] + [pickle_load(f) for f in index_files[1:]]
    corpus_lookups[dataset_name] = []
    for p_reps, p_lookup in tqdm.tqdm(shards, desc=f'Loading shards into index for {dataset_name}', total=len(index_files)):
        retrievers[dataset_name].add(p_reps)
        corpus_lookups[dataset_name] += p_lookup


def pickle_load(path):
    """Load one (embeddings, passage-id lookup) shard from a pickle file."""
    with open(path, 'rb') as f:
        reps, lookup = pickle.load(f)
    return np.array(reps), lookup


def load_queries(dataset_name):
    """Load queries and relevance judgments for a BEIR dataset via ir_datasets."""
    global queries, q_lookups, qrels
    # scifact has splits; only its test split carries qrels
    dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))

    queries[dataset_name] = []
    q_lookups[dataset_name] = {}
    qrels[dataset_name] = {}
    for query in dataset.queries_iter():
        queries[dataset_name].append(query.text)
        q_lookups[dataset_name][query.query_id] = query.text
    for qrel in dataset.qrels_iter():
        if qrel.query_id not in qrels[dataset_name]:
            qrels[dataset_name][qrel.query_id] = {}
        qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance


@spaces.GPU
def encode_queries(dataset_name, postfix):
    """Encode every query of a dataset, with the user's prompt appended.

    Returns an (num_queries, dim) float32 array of L2-normalized embeddings,
    in the same order as queries[dataset_name].
    """
    global queries, tokenizer, model
    model = model.cuda()
    input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[dataset_name]]

    encoded_embeds = []
    batch_size = 32
    # FIX: inference must run under no_grad — autocast alone does not disable
    # gradient tracking, so without this the autograd graph is retained for
    # every batch and GPU memory grows until OOM on larger query sets.
    with torch.no_grad():
        for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc="Encoding queries"):
            batch_input_texts = input_texts[start_idx: start_idx + batch_size]

            batch_dict = create_batch_dict(tokenizer, batch_input_texts)
            batch_dict = {k: v.to(model.device) for k, v in batch_dict.items()}

            with torch.cuda.amp.autocast():
                outputs = model(**batch_dict)
                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'])
                embeds = F.normalize(embeds, p=2, dim=-1)
                encoded_embeds.append(embeds.cpu().numpy())

    return np.concatenate(encoded_embeds, axis=0)


def search_queries(dataset_name, q_reps, depth=1000):
    """Search the FAISS index; return (score matrix, passage-id matrix)."""
    all_scores, all_indices = retrievers[dataset_name].search(q_reps, depth)
    # Map FAISS row indices back to the corpus passage ids
    psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
    return all_scores, np.array(psg_indices)


def evaluate(qrels, results, k_values):
    """Compute NDCG@k and Recall@k over all queries with pytrec_eval."""
    evaluator = pytrec_eval.RelevanceEvaluator(
        qrels,
        {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
    )
    scores = evaluator.evaluate(results)

    metrics = {}
    for k in k_values:
        metrics[f"NDCG@{k}"] = round(np.mean([query_scores[f"ndcg_cut_{k}"] for query_scores in scores.values()]), 3)
        metrics[f"Recall@{k}"] = round(np.mean([query_scores[f"recall_{k}"] for query_scores in scores.values()]), 3)
    return metrics


def run_evaluation(dataset, postfix):
    """Encode queries with the given prompt, retrieve, and score one dataset."""
    global current_dataset
    # Lazily load anything not pre-loaded at startup
    if dataset not in retrievers or dataset not in queries:
        load_corpus_embeddings(dataset)
        load_queries(dataset)
    current_dataset = dataset

    q_reps = encode_queries(dataset, postfix)
    all_scores, psg_indices = search_queries(dataset, q_reps)
    # q_lookups preserves insertion order, which is the order queries were
    # encoded in, so zipping its keys with the result rows lines up correctly.
    results = {qid: dict(zip(doc_ids, map(float, scores)))
               for qid, scores, doc_ids in zip(q_lookups[dataset].keys(), all_scores, psg_indices)}

    metrics = evaluate(qrels[dataset], results, k_values=[10, 100])
    return {
        "NDCG@10": metrics["NDCG@10"],
        "Recall@100": metrics["Recall@100"]
    }


def gradio_interface(dataset, postfix):
    """Gradio entry point: evaluate `dataset` with postfix prompt `postfix`."""
    return run_evaluation(dataset, postfix)


# Load model and initial datasets
load_model()
for dataset in datasets:
    print(f"Loading dataset: {dataset}")
    load_corpus_embeddings(dataset)
    load_queries(dataset)

# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=datasets, label="Dataset", value="scifact"),
        gr.Textbox(label="Prompt")
    ],
    outputs=gr.JSON(label="Evaluation Results"),
    title="Promptriever Demo",
    description="Select a dataset and enter a postfix prompt to evaluate the model's performance. Note: it takes about **ten seconds** for each dataset."
)

# Launch the interface
iface.launch()