import ast
import os

import pandas as pd
import gradio as gr
import google.generativeai as genai

import knowledge_triples_utils as kutils

# Configure the Gemini client from the environment.
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
genai.configure(api_key=GOOGLE_API_KEY)

# Text-only and multimodal Gemini models.
gemini_pro = genai.GenerativeModel(model_name="models/gemini-pro")
gemini_pro_vision = genai.GenerativeModel(model_name="models/gemini-pro-vision")


# Load precomputed text-node embeddings; the vectors are stored as strings
# in the CSV, so parse them back into Python lists.
all_nodes_csv_path = "AHU_17_All_Nodes_embedding.csv"
all_nodes_df = pd.read_csv(all_nodes_csv_path)
all_nodes_df["node_embedding"] = all_nodes_df["node_embedding"].apply(
    lambda x: ast.literal_eval(x) if pd.notna(x) else x
)


# Load precomputed image embeddings (over each image's description and title)
# and parse them the same way.
all_images_csv_path = "AHU_17_All_Images_embeddings_hf.csv"
all_images_df = pd.read_csv(all_images_csv_path)
all_images_df["desc_title_embedding"] = all_images_df["desc_title_embedding"].apply(
    lambda x: ast.literal_eval(x) if pd.notna(x) else x
)
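
# knowledge_triples_utils is not included in this file. As a rough sketch of
# the retrieval step it is assumed to perform (embed the query, rank rows by
# cosine similarity against the stored vectors, keep the top n), something
# like the helper below would work. The embedding model name and the function
# body are assumptions for illustration, not the actual kutils implementation.
import numpy as np

def _sketch_top_n_by_cosine(query, df, column_name, top_n=3):
    """Illustrative only: rank rows of df by cosine similarity to the query."""
    # Embed the query text (assumed embedding model).
    query_vec = np.array(
        genai.embed_content(model="models/embedding-001", content=query)["embedding"]
    )
    # Cosine similarity against each stored embedding; skip unparsed rows.
    scores = df[column_name].apply(
        lambda v: float(np.dot(v, query_vec))
        / (np.linalg.norm(v) * np.linalg.norm(query_vec))
        if isinstance(v, list)
        else float("-inf")
    )
    # Return the top_n most similar rows.
    return df.loc[scores.nlargest(top_n).index]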


# Answer a user query with Gemini, grounded in retrieved text and image context.
def answer_query(query):
    # Match the user text query against "node_embedding" to find the most relevant text chunks.
    matching_results_text = kutils.get_similar_text_from_query(
        query,
        all_nodes_df,
        column_name="node_embedding",
        top_n=3,
        print_citation=False,
    )

    # Match the user text query against "desc_title_embedding" to find the most relevant images.
    matching_results_images = kutils.get_relevant_images_from_query(
        query,
        all_images_df,
        column_name="desc_title_embedding",
        top_n=3,
    )

    # combine all the selected relevant text chunks
    context_text = []
    for key, value in matching_results_text.items():
        context_text.append(value["node_text"])
    final_context_text = "\n".join(context_text)

    # Interleave the relevant images with their Gemini-generated descriptions.
    context_images = []
    for key, value in matching_results_images.items():
        context_images.extend(
            ["Image: ", value["image_object"], "Caption: ", value["image_description"]]
        )

    instructions = '''
        Answer the query using the text context given in "text_context" and the image
        context given in "image_context" along with each image's caption.
        Base your response on "text_context" and "image_context" only. Do not use any
        numbers or percentages that are not present in "image_context".
        Context:
        '''

    # Assemble the multimodal prompt: query, instructions, text context, then
    # the interleaved image objects and captions.
    final_prompt = [
        "QUERY: " + query + " ANSWER: ",
        instructions,
        "text_context:",
        final_context_text,
        "image_context:",
    ]
    final_prompt.extend(context_images)

    # Stream the multimodal response and assemble the text chunks.
    response = gemini_pro_vision.generate_content(
        final_prompt,
        stream=True,
    )
    response_list = []
    for chunk in response:
        response_list.append(chunk.text)
    answer = "".join(response_list)

    # Return the generated answer along with the top-ranked matching image.
    return answer, matching_results_images[0]["image_object"]


# Minimal Gradio UI: a text query in; the generated answer and the best-matching image out.
demo = gr.Interface(
    fn=answer_query,
    inputs="textbox",
    outputs=["textbox", "image"],
)

if __name__ == "__main__":
    demo.launch()
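
# Illustrative usage (hypothetical query), bypassing the Gradio UI:
#   answer, image = answer_query("How is the supply air temperature for AHU 17 controlled?")
#   print(answer)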