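"""Chat engine over a VectorStoreIndex, built as a LlamaIndex QueryPipeline.

The pipeline rewrites the latest user message into a standalone search query
from the chat history, retrieves with both the rewritten and the original
query, reranks the combined nodes with ColBERT, and synthesizes a
history-aware response.
"""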
import os
import time
from llama_index.core import VectorStoreIndex
from llama_index.core.query_pipeline import (
QueryPipeline,
InputComponent,
ArgPackComponent,
)
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.openai import OpenAI
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from typing import Any, Dict, List, Optional
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import ChatMessage
from llama_index.core.query_pipeline import CustomQueryComponent
from llama_index.core.schema import NodeWithScore
from llama_index.core.memory import ChatMemoryBuffer
llm = OpenAI(
model="gpt-3.5-turbo-0125",
api_key=os.getenv("OPENAI_API_KEY"),
)
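# NOTE: the model is pinned to a dated snapshot here; any OpenAI chat model
# name should work in its place.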
# First, we create an input component to capture the user query
input_component = InputComponent()
# Next, we use the LLM to rewrite a user query
rewrite = (
"Please write a query to a semantic search engine using the current conversation.\n"
"\n"
"\n"
"{chat_history_str}"
"\n"
"\n"
"Latest message: {query_str}\n"
'Query:"""\n'
)
rewrite_template = PromptTemplate(rewrite)
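# e.g. given a history about LlamaIndex and the latest message "how do I
# install it?", the LLM might emit a standalone query such as
# "LlamaIndex installation steps" (illustrative output, not guaranteed)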
# we retrieve twice, so we pack the retrieved nodes from both runs into a single list
argpack_component = ArgPackComponent()
# then postprocess/rerank with Colbert
reranker = ColbertRerank(top_n=3)
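# top_n=3 keeps only the three highest-scoring nodes out of the combined
# results of the two retriever runs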
DEFAULT_CONTEXT_PROMPT = (
    "Here is some context that may be relevant:\n"
    "-----\n"
    "{node_context}\n"
    "-----\n"
    "Please write a response to the following question, using the above context:\n"
    "{query_str}\n"
    "Please format your response in the following way:\n"
    "Your answer here.\n"
    "Reference:\n"
    " Your references here (e.g. page numbers, titles, etc.).\n"
)
class ResponseWithChatHistory(CustomQueryComponent):
llm: OpenAI = Field(..., description="OpenAI LLM")
system_prompt: Optional[str] = Field(
default=None, description="System prompt to use for the LLM"
)
context_prompt: str = Field(
default=DEFAULT_CONTEXT_PROMPT,
description="Context prompt to use for the LLM",
)
def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Validate component inputs during run_component."""
# NOTE: this is OPTIONAL but we show you where to do validation as an example
return input
@property
def _input_keys(self) -> set:
"""Input keys dict."""
# NOTE: These are required inputs. If you have optional inputs please override
# `optional_input_keys_dict`
return {"chat_history", "nodes", "query_str"}
@property
def _output_keys(self) -> set:
return {"response"}
def _prepare_context(
self,
chat_history: List[ChatMessage],
nodes: List[NodeWithScore],
query_str: str,
) -> List[ChatMessage]:
node_context = ""
for idx, node in enumerate(nodes):
node_text = node.get_content(metadata_mode="llm")
node_context += f"Context Chunk {idx}:\n{node_text}\n\n"
formatted_context = self.context_prompt.format(
node_context=node_context, query_str=query_str
)
user_message = ChatMessage(role="user", content=formatted_context)
chat_history.append(user_message)
if self.system_prompt is not None:
chat_history = [
ChatMessage(role="system", content=self.system_prompt)
] + chat_history
return chat_history
def _run_component(self, **kwargs) -> Dict[str, Any]:
"""Run the component."""
chat_history = kwargs["chat_history"]
nodes = kwargs["nodes"]
query_str = kwargs["query_str"]
prepared_context = self._prepare_context(chat_history, nodes, query_str)
        response = self.llm.chat(prepared_context)
return {"response": response}
async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
"""Run the component asynchronously."""
# NOTE: Optional, but async LLM calls are easy to implement
chat_history = kwargs["chat_history"]
nodes = kwargs["nodes"]
query_str = kwargs["query_str"]
prepared_context = self._prepare_context(chat_history, nodes, query_str)
        response = await self.llm.achat(prepared_context)
return {"response": response}
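
# A quick standalone check of the component (a sketch, not part of the
# original file; assumes `nodes` is a List[NodeWithScore] from a prior
# retrieval):
#
#   out = ResponseWithChatHistory(llm=llm).run_component(
#       chat_history=[], nodes=nodes, query_str="What does the corpus cover?"
#   )
#   print(out["response"])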
class LlamaCustomV2:
response_component = ResponseWithChatHistory(
llm=llm,
system_prompt=(
"You are a Q&A system. You will be provided with the previous chat history, "
"as well as possibly relevant context, to assist in answering a user message."
),
)
def __init__(self, model_name: str, index: VectorStoreIndex):
self.model_name = model_name
self.index = index
self.retriever = index.as_retriever()
self.chat_mode = "condense_plus_context"
self.memory = ChatMemoryBuffer.from_defaults()
self.verbose = True
self._build_pipeline()
def _build_pipeline(self):
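        """Wire the pipeline DAG:

        input.query_str -> rewrite_template -> llm -> rewrite_retriever -> join
        input.query_str -> query_retriever ------------------------------> join
        join -> reranker -> response_component

        input also feeds query_str to the reranker, and query_str plus
        chat_history to the response component.
        """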
self.pipeline = QueryPipeline(
modules={
"input": input_component,
"rewrite_template": rewrite_template,
"llm": llm,
"rewrite_retriever": self.retriever,
"query_retriever": self.retriever,
"join": argpack_component,
"reranker": reranker,
"response_component": self.response_component,
},
verbose=self.verbose,
)
        # run both retrievers -- once with the LLM-rewritten query, once with the original query
self.pipeline.add_link(
"input", "rewrite_template", src_key="query_str", dest_key="query_str"
)
self.pipeline.add_link(
"input",
"rewrite_template",
src_key="chat_history_str",
dest_key="chat_history_str",
)
self.pipeline.add_link("rewrite_template", "llm")
self.pipeline.add_link("llm", "rewrite_retriever")
self.pipeline.add_link("input", "query_retriever", src_key="query_str")
# each input to the argpack component needs a dest key -- it can be anything
# then, the argpack component will pack all the inputs into a single list
self.pipeline.add_link("rewrite_retriever", "join", dest_key="rewrite_nodes")
self.pipeline.add_link("query_retriever", "join", dest_key="query_nodes")
# reranker needs the packed nodes and the query string
self.pipeline.add_link("join", "reranker", dest_key="nodes")
self.pipeline.add_link(
"input", "reranker", src_key="query_str", dest_key="query_str"
)
# synthesizer needs the reranked nodes and query str
self.pipeline.add_link("reranker", "response_component", dest_key="nodes")
self.pipeline.add_link(
"input", "response_component", src_key="query_str", dest_key="query_str"
)
self.pipeline.add_link(
"input",
"response_component",
src_key="chat_history",
dest_key="chat_history",
)
    def get_response(self, query_str: str, chat_history: List[ChatMessage]):
        # the internal memory buffer is the source of truth: it supersedes
        # the chat_history argument, which is currently ignored
        chat_history = self.memory.get()
        chat_history_str = "\n".join(str(x) for x in chat_history)
        response = self.pipeline.run(
            query_str=query_str,
            chat_history=chat_history,
            chat_history_str=chat_history_str,
        )
user_msg = ChatMessage(role="user", content=query_str)
print("user_msg: ", str(user_msg))
print("response: ", str(response.message))
self.memory.put(user_msg)
self.memory.put(response.message)
return str(response.message)
def get_stream_response(self, query_str: str, chat_history: List[ChatMessage]):
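        # simulated streaming: run the full (non-streaming) pipeline, then
        # yield the response word by word with a small delay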
response = self.get_response(query_str=query_str, chat_history=chat_history)
for word in response.split():
yield word + " "
time.sleep(0.05)
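

# A minimal usage sketch (an assumption, not part of the original file): it
# presumes a ./data directory of documents and OPENAI_API_KEY set in the
# environment so that both the LLM and the default embedding model work.
if __name__ == "__main__":
    from llama_index.core import SimpleDirectoryReader

    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)
    engine = LlamaCustomV2(model_name="gpt-3.5-turbo-0125", index=index)
    print(engine.get_response("What are these documents about?", chat_history=[]))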