# rag/rag_pipeline.py
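"""RAG pipeline over a JSON corpus of study records.

Builds a sentence-window vector index over the loaded studies and exposes
structured field extraction (extract_study_info) and free-form Q&A (query).
"""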
import json
from typing import Any, Dict, List, Optional

from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser


class RAGPipeline:
    """RAG pipeline over study documents using sentence-window retrieval."""

    def __init__(self, study_json, use_semantic_splitter=False):
        self.study_json = study_json
        self.use_semantic_splitter = use_semantic_splitter
        self.documents = None
        self.index = None
        self.load_documents()
        self.build_index()

    def load_documents(self):
        """Load study records from JSON and wrap each one in a Document."""
        if self.documents is None:
            with open(self.study_json, "r") as f:
                self.data = json.load(f)

            self.documents = []
            for index, doc_data in enumerate(self.data):
                # Embed the fields most useful for retrieval in the node text;
                # .get() keeps loading tolerant of records with missing fields.
                doc_content = (
                    f"Title: {doc_data.get('title', '')}\n"
                    f"Authors: {', '.join(doc_data.get('authors', []))}\n"
                    f"Full Text: {doc_data.get('full_text', '')}"
                )
                metadata = {
                    "title": doc_data.get("title"),
                    "abstract": doc_data.get("abstract"),
                    "authors": doc_data.get("authors", []),
                    "year": doc_data.get("year"),
                    "doi": doc_data.get("doi"),
                }
                self.documents.append(
                    Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
                )

    def build_index(self):
        """Chunk documents into sentence-window nodes and build the vector index."""
        if self.index is None:
            # Small chunks keep each node focused on a few sentences; the
            # surrounding window (below) preserves local context.
            sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)

            def _split(text: str) -> List[str]:
                return sentence_splitter.split_text(text)

            # Each node stores its neighbouring sentences under the "window"
            # metadata key and its own sentence under "original_text".
            node_parser = SentenceWindowNodeParser.from_defaults(
                sentence_splitter=_split,
                window_size=3,
                window_metadata_key="window",
                original_text_metadata_key="original_text",
            )
            nodes = node_parser.get_nodes_from_documents(self.documents)
            self.index = VectorStoreIndex(nodes)
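    # Note: sentence-window indexes are usually paired at query time with
    # llama_index's MetadataReplacementPostProcessor, so the LLM sees the
    # stored window instead of the bare sentence. A sketch of that wiring
    # (not enabled here, since it changes the engines built below):
    #
    #   from llama_index.core.postprocessor import MetadataReplacementPostProcessor
    #   engine = self.index.as_query_engine(
    #       node_postprocessors=[
    #           MetadataReplacementPostProcessor(target_metadata_key="window")
    #       ]
    #   )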

    def extract_study_info(self) -> Dict[str, Any]:
        """Ask the index for standard study fields and parse them into a dict."""
        extraction_prompt = PromptTemplate(
            "Based on the given context, please extract the following information about the study:\n"
            "1. Study ID\n"
            "2. Author(s)\n"
            "3. Year\n"
            "4. Title\n"
            "5. Study design\n"
            "6. Study area/region\n"
            "7. Study population\n"
            "8. Disease under study\n"
            "9. Duration of study\n"
            "If the information is not available, please respond with 'Not found' for that field.\n"
            "Context: {context_str}\n"
            "Extracted information:"
        )
        query_engine = self.index.as_query_engine(
            text_qa_template=extraction_prompt, similarity_top_k=5
        )
        response = query_engine.query("Extract study information")

        # Parse "Key: value" lines from the response into snake_case keys.
        extracted_info = {}
        for line in response.response.split("\n"):
            if ":" in line:
                key, value = line.split(":", 1)
                extracted_info[key.strip().lower().replace(" ", "_")] = value.strip()
        return extracted_info

    def query(
        self, question: str, prompt_template: Optional[PromptTemplate] = None, **kwargs
    ) -> Dict[str, Any]:
        """Answer a free-form question and return the answer with its sources."""
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question: {query_str}\n"
                "Include all relevant information from the provided context. "
                "If information comes from multiple sources, please mention all of them. "
                "If the information is not available in the context, please state that clearly. "
                "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
            )
        # Extra kwargs configure the engine rather than the query itself:
        # BaseQueryEngine.query() accepts only the question, so passing
        # kwargs there would raise a TypeError.
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template, similarity_top_k=5, **kwargs
        )
        response = query_engine.query(question)
        return {
            "question": question,
            "answer": response.response,
            "sources": [node.metadata for node in response.source_nodes],
        }
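

# A minimal usage sketch, assuming llama_index's default embedding model and
# LLM are configured (e.g. OPENAI_API_KEY set in the environment) and that
# "data/studies.json" is a hypothetical file holding a list of records with
# "title", "authors", "abstract", "year", "doi", and "full_text" keys.
if __name__ == "__main__":
    pipeline = RAGPipeline("data/studies.json")
    print(pipeline.extract_study_info())
    result = pipeline.query("What disease does the study investigate?")
    print(result["answer"])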