File size: 2,103 Bytes
b075822
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e32aa7
 
 
 
 
 
 
 
b075822
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e32aa7
b075822
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from llama_index.llms.openai import OpenAI
from llama_index.core import load_index_from_storage, get_response_synthesizer
import matplotlib.pyplot as plt
import os
from PIL import Image
from llama_index.core import PromptTemplate
from awsfunctions import download_files_from_s3, check_file_exists_in_s3
import tempfile, shutil
import streamlit as st

@st.cache_resource
def get_image_from_s3(image_path):
    """Download a single image from S3 and return it as a loaded PIL image.

    The file is fetched into a throwaway temporary directory via
    ``download_files_from_s3`` and opened from ``<temp_dir>/<image_path>``.
    The result is cached by Streamlit (keyed on ``image_path``), so repeated
    reruns do not download the same image again.

    Args:
        image_path: Path of the image, passed to ``download_files_from_s3``
            and joined onto the temporary directory to locate the download.

    Returns:
        The ``PIL.Image.Image`` with its pixel data already read into memory.
    """
    # TemporaryDirectory removes the directory even if the download or the
    # open raises, so failed calls no longer leave stray temp dirs behind.
    with tempfile.TemporaryDirectory() as temp_dir:
        download_files_from_s3(temp_dir, [image_path])
        with Image.open(os.path.join(temp_dir, image_path)) as image:
            # Image.open is lazy: force the pixel data to be read now, while
            # the backing file still exists, and release the file handle
            # before the directory is deleted.
            image.load()
    return image

def plot_images(image_paths):
    """Display every image in *image_paths* that exists in S3.

    Each path is checked with ``check_file_exists_in_s3``; paths that are
    present are fetched with ``get_image_from_s3`` and rendered through
    ``st.image``. Paths that are not present are skipped.

    Args:
        image_paths: Iterable of image paths to display.

    Returns:
        None. Output is rendered into the Streamlit app.
    """
    # No matplotlib figure is created here: rendering goes through st.image,
    # and an unused plt.figure() per call would accumulate open figures.
    for img_path in image_paths:
        if check_file_exists_in_s3(img_path):
            st.image(get_image_from_s3(img_path))

def retrieve_and_query(query, retriever_engine, *, image_score_threshold=0.25):
    """Retrieve nodes for *query*, synthesize an answer, and collect image paths.

    The nodes returned by ``retriever_engine.retrieve(query)`` are passed to a
    "refine" response synthesizer that uses a custom QA prompt and the
    ``gpt-4o-mini`` model at temperature 0. Separately, the same nodes are
    scanned for JPEG/PNG entries whose score exceeds the threshold.

    Args:
        query: The query passed to the retriever and to the synthesizer.
        retriever_engine: Object exposing ``retrieve(query)``; its results
            must provide ``metadata`` and ``score`` attributes.
        image_score_threshold: An image node is kept only when its score is
            strictly greater than this value. Defaults to 0.25.

    Returns:
        A ``(response, retrieved_image_path_list)`` tuple: the synthesizer's
        response and the list of ``file_path`` metadata values of the image
        nodes that passed the score threshold, in retrieval order.
    """
    retrieval_results = retriever_engine.retrieve(query)

    qa_tmpl = PromptTemplate(
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information , "
        "answer the query in detail.\n"
        "Query: {query_str}\n"
        "Answer: "
    )

    llm = OpenAI(model="gpt-4o-mini", temperature=0)
    response_synthesizer = get_response_synthesizer(
        response_mode="refine", text_qa_template=qa_tmpl, llm=llm
    )
    response = response_synthesizer.synthesize(query, nodes=retrieval_results)

    image_file_types = ("image/jpeg", "image/png")
    retrieved_image_path_list = []
    for node in retrieval_results:
        metadata = node.metadata or {}
        # Nodes without a file_type entry are treated as non-images instead
        # of raising KeyError.
        if metadata.get("file_type") not in image_file_types:
            continue
        # A node may carry no score; comparing None with a float would raise
        # TypeError, so unscored nodes are skipped.
        if node.score is None or node.score <= image_score_threshold:
            continue
        file_path = metadata.get("file_path")
        if file_path is not None:
            retrieved_image_path_list.append(file_path)

    return response, retrieved_image_path_list