Bot_Development / core /multimodal.py
dsmultimedika's picture
fix : update code
0767396
import logging
from typing import Optional

from llama_index.core.base.response.schema import Response
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import ImageNode, NodeWithScore, MetadataMode
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

from core.prompt import MULTOMODAL_QUERY_TEMPLATE
# Multimodal LLM client, constructed once when this module is imported.
# NOTE(review): the name says "gpt_4o" but the configured model is
# "gpt-4o-mini" -- confirm which one is intended before renaming either.
gpt_4o = OpenAIMultiModal(
    model="gpt-4o-mini",
    max_new_tokens=4096,
)

# Default QA prompt; MultimodalQueryEngine falls back to it when no
# ``qa_prompt`` is passed to its constructor.
QA_PROMPT = PromptTemplate(
    MULTOMODAL_QUERY_TEMPLATE,
)
class MultimodalQueryEngine(CustomQueryEngine):
    """Custom multimodal query engine.

    Retrieves text nodes with ``retriever``, turns the ``image_link`` metadata
    found on those nodes into image nodes, and asks ``multi_modal_llm`` to
    answer the query from the formatted text context plus those images.
    """

    # Prompt template; formatted with ``context_str`` and ``query_str``.
    qa_prompt: PromptTemplate
    # Retriever that supplies the text nodes for a query.
    retriever: BaseRetriever
    # Multimodal model that receives the prompt and the image documents.
    multi_modal_llm: OpenAIMultiModal

    def __init__(self, qa_prompt: Optional[PromptTemplate] = None, **kwargs) -> None:
        """Initialize, falling back to the module-level ``QA_PROMPT``.

        Args:
            qa_prompt: Prompt template to use; ``QA_PROMPT`` when omitted.
            **kwargs: Remaining fields (e.g. ``retriever``, ``multi_modal_llm``)
                forwarded unchanged to the base class constructor.
        """
        super().__init__(qa_prompt=qa_prompt or QA_PROMPT, **kwargs)

    def _collect_image_nodes(self, nodes):
        """Build image nodes from the ``image_link`` metadata of ``nodes``.

        ``image_link`` may hold a single value or a list/tuple of values.
        Only non-blank strings are kept: a missing key, ``None``, empty
        strings/lists and any other non-string entry are skipped, so an
        ``ImageNode`` is never created without a usable URL.

        Args:
            nodes: Retrieved nodes whose ``metadata`` mapping is inspected.

        Returns:
            A list of ``NodeWithScore`` objects, one per kept link, each
            wrapping an ``ImageNode`` built from that link.
        """
        image_nodes = []
        for text_node in nodes:
            raw_links = text_node.metadata.get("image_link")
            # Normalize the single-value form to a sequence.
            if not isinstance(raw_links, (list, tuple)):
                raw_links = [raw_links]
            for link in raw_links:
                if isinstance(link, str) and link.strip():
                    image_nodes.append(
                        NodeWithScore(node=ImageNode(image_url=link))
                    )
        return image_nodes

    def custom_query(self, query_str: str) -> Response:
        """Answer ``query_str`` from retrieved text and the linked images.

        Args:
            query_str: The user query.

        Returns:
            A ``Response`` whose text is the stringified LLM completion, whose
            ``source_nodes`` are the retrieved text nodes, and whose metadata
            carries both ``text_nodes`` and ``image_nodes``.
        """
        # Retrieve text nodes.
        nodes = self.retriever.retrieve(query_str)

        # Create image nodes from the links stored on the text nodes.
        image_nodes = self._collect_image_nodes(nodes)
        logging.getLogger(__name__).debug("image_nodes: %s", image_nodes)

        # Create the context string from the text nodes and fill the prompt.
        context_str = "\n\n".join(
            r.get_content(metadata_mode=MetadataMode.LLM) for r in nodes
        )
        fmt_prompt = self.qa_prompt.format(
            context_str=context_str, query_str=query_str
        )

        # Synthesize an answer from the formatted text and the images.
        llm_response = self.multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[image_node.node for image_node in image_nodes],
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": nodes, "image_nodes": image_nodes},
        )