"""Custom multimodal query engine: text retrieval plus linked-image context."""

import logging
from typing import Optional

from llama_index.core.base.response.schema import Response
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import ImageNode, MetadataMode, NodeWithScore
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

# NOTE(review): "MULTOMODAL" looks like a typo for "MULTIMODAL", but the name
# is defined in core.prompt — fix it there first before renaming it here.
from core.prompt import MULTOMODAL_QUERY_TEMPLATE

logger = logging.getLogger(__name__)

# Module-level multimodal LLM instance available to callers constructing the
# engine; note the model string is "gpt-4o-mini" despite the variable name.
gpt_4o = OpenAIMultiModal(model="gpt-4o-mini", max_new_tokens=4096)
QA_PROMPT = PromptTemplate(MULTOMODAL_QUERY_TEMPLATE)


class MultimodalQueryEngine(CustomQueryEngine):
    """Custom multimodal Query Engine.

    Takes in a retriever to retrieve a set of document nodes.
    Also takes in a prompt template and multimodal model.
    """

    qa_prompt: PromptTemplate
    retriever: BaseRetriever
    multi_modal_llm: OpenAIMultiModal

    def __init__(self, qa_prompt: Optional[PromptTemplate] = None, **kwargs) -> None:
        """Initialize, defaulting to the module-level QA prompt when none given."""
        super().__init__(qa_prompt=qa_prompt or QA_PROMPT, **kwargs)

    def custom_query(self, query_str: str) -> Response:
        """Retrieve text nodes, gather their linked images, and synthesize an answer.

        Args:
            query_str: The user's natural-language query.

        Returns:
            Response whose ``metadata`` carries both the text nodes and the
            image nodes that were handed to the multimodal LLM.
        """
        # Retrieve text nodes relevant to the query.
        nodes = self.retriever.retrieve(query_str)

        # Build ImageNode items from any "image_link" metadata on the text
        # nodes. The value may be a single URL string or a list of URL
        # strings; falsy entries (None, "", []) are skipped — the original
        # membership test let None through and produced ImageNode(image_url=None).
        image_nodes: list[NodeWithScore] = []
        for n in nodes:
            raw_links = n.metadata.get("image_link")
            if not raw_links:
                continue
            links = raw_links if isinstance(raw_links, list) else [raw_links]
            image_nodes.extend(
                NodeWithScore(node=ImageNode(image_url=link))
                for link in links
                if link  # drop empty entries inside a list
            )
        # Debug-level logging instead of a bare print in the production path.
        logger.debug("image_nodes: %s", image_nodes)

        # Create context string from the text nodes, dump into the prompt.
        context_str = "\n\n".join(
            r.get_content(metadata_mode=MetadataMode.LLM) for r in nodes
        )
        fmt_prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)

        # Synthesize an answer from the formatted text and images.
        llm_response = self.multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": nodes, "image_nodes": image_nodes},
        )