# NOTE(review): the three lines below were page-header residue from the site this
# file was scraped from ("Spaces: / Sleeping / Sleeping"); kept as a comment so
# the module parses.
from typing import Optional

from llama_index.core.base.response.schema import Response
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import ImageNode, MetadataMode, NodeWithScore
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

# NOTE: "MULTOMODAL" is the name actually declared in core.prompt — do not
# "correct" the spelling here without renaming it at the definition site.
from core.prompt import MULTOMODAL_QUERY_TEMPLATE

# Module-level defaults used by MultimodalQueryEngine.
# gpt_4o is named for historical reasons; it is configured with gpt-4o-mini.
gpt_4o = OpenAIMultiModal(model="gpt-4o-mini", max_new_tokens=4096)
QA_PROMPT = PromptTemplate(MULTOMODAL_QUERY_TEMPLATE)
class MultimodalQueryEngine(CustomQueryEngine):
    """Custom multimodal query engine.

    Retrieves a set of text nodes with ``retriever``, collects any image
    links found in the node metadata, and synthesizes an answer with a
    multimodal LLM that sees both the text context and the images.
    """

    # Prompt template expecting ``context_str`` and ``query_str`` variables.
    qa_prompt: PromptTemplate
    # Retriever that produces the text nodes used as context.
    retriever: BaseRetriever
    # Multimodal LLM that accepts a text prompt plus image documents.
    multi_modal_llm: OpenAIMultiModal

    def __init__(self, qa_prompt: Optional[PromptTemplate] = None, **kwargs) -> None:
        """Initialize, falling back to the module-level ``QA_PROMPT``."""
        super().__init__(qa_prompt=qa_prompt or QA_PROMPT, **kwargs)

    def custom_query(self, query_str: str) -> Response:
        """Answer ``query_str`` from retrieved text nodes and linked images.

        Args:
            query_str: The user query to retrieve context for and answer.

        Returns:
            A ``Response`` whose ``source_nodes`` are the retrieved text
            nodes and whose metadata carries both text and image nodes.
        """
        # Retrieve text nodes for the query.
        nodes = self.retriever.retrieve(query_str)

        # Collect ImageNode items from node metadata. ``image_link`` may be
        # absent, a single URL string, or a list of URLs; skip empty values
        # (None added for robustness alongside "" and []).
        image_nodes = []
        for node in nodes:
            links = node.metadata.get("image_link")
            if links in (None, "", []):
                continue
            if not isinstance(links, list):
                links = [links]
            for link in links:
                if link not in ("", []):
                    image_nodes.append(NodeWithScore(node=ImageNode(image_url=link)))
        print("image_nodes: {}".format(image_nodes))

        # Build the context string from the text nodes and fill the prompt.
        context_str = "\n\n".join(
            r.get_content(metadata_mode=MetadataMode.LLM) for r in nodes
        )
        fmt_prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)

        # Synthesize an answer from the formatted text and the images.
        llm_response = self.multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": nodes, "image_nodes": image_nodes},
        )