|
import pandas as pd |
|
|
|
import os, sys |
|
|
|
import ast |
|
|
|
import gradio as gr |
|
|
|
import google.generativeai as genai |
|
|
|
# Fail fast with a KeyError if the API key is not configured in the environment.
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]

genai.configure(api_key=GOOGLE_API_KEY)

# Text-only model (currently unused in this script) and the multimodal model
# used by answer_query() to reason over text + image context.
gemini_pro = genai.GenerativeModel(model_name="models/gemini-pro")

gemini_pro_vision = genai.GenerativeModel(model_name="models/gemini-pro-vision")
|
|
|
|
|
import knowledge_triples_utils as kutils |
|
|
|
|
|
def _load_embeddings_df(csv_path, embedding_column):
    """Load a CSV and parse its stringified embedding column into Python lists.

    Embeddings are stored in the CSV as string representations of lists
    (e.g. "[0.1, 0.2, ...]"); ast.literal_eval restores them. NaN cells
    are left untouched.

    Args:
        csv_path: Path to the CSV file to load.
        embedding_column: Name of the column holding stringified embeddings.

    Returns:
        A pandas DataFrame with the embedding column parsed into lists.
    """
    df = pd.read_csv(csv_path)
    df[embedding_column] = df[embedding_column].apply(
        lambda x: ast.literal_eval(x) if pd.notna(x) else x
    )
    return df


# Knowledge-graph node texts with precomputed embeddings.
all_nodes_csv_path = "AHU_17_All_Nodes_embedding.csv"
all_nodes_df = _load_embeddings_df(all_nodes_csv_path, "node_embedding")

# Image metadata (descriptions/titles) with precomputed embeddings.
all_images_csv_path = "AHU_17_All_Images_embeddings_hf.csv"
all_images_df = _load_embeddings_df(all_images_csv_path, "desc_title_embedding")
|
|
|
|
|
|
|
def answer_query(query):
    """Answer a user query grounded in retrieved text and image context.

    Retrieves the top-3 most similar knowledge-graph nodes and the top-3
    most relevant images (by embedding similarity), assembles a multimodal
    prompt, and asks the Gemini Pro Vision model for a grounded answer.

    Args:
        query: The user's natural-language question.

    Returns:
        A tuple of (answer_text, top_image_object) where the second element
        is the image object of the best-matching image result.
    """
    # Top-3 knowledge-graph nodes most similar to the query embedding.
    matching_results_text = kutils.get_similar_text_from_query(
        query,
        all_nodes_df,
        column_name="node_embedding",
        top_n=3,
        print_citation=False,
    )

    # Top-3 images whose description/title embedding matches the query.
    matching_results_images = kutils.get_relevant_images_from_query(
        query,
        all_images_df,
        column_name="desc_title_embedding",
        top_n=3,
    )

    # Join the retrieved node texts once and reuse the result in the prompt
    # (previously this was computed and then redundantly re-joined below).
    context_text = [value["node_text"] for value in matching_results_text.values()]
    final_context_text = "\n".join(context_text)

    # Interleave image objects with their captions for the multimodal prompt.
    context_images = []
    for value in matching_results_images.values():
        context_images.extend(
            ["Image: ", value["image_object"], "Caption: ", value["image_description"]]
        )

    instructions = '''
You will answer the query based on the text context given in "text_context" and Image context given
in "image_context" along with its Caption:\n
Base your response on "text_context" and "image_context". Do not use any numbers or percentages that are
not present in the "image_context".
Context:
    '''

    # Assemble the prompt: query, instructions, text context, then images.
    final_prompt = [
        "QUERY: " + query + " ANSWER: ",
        instructions,
        "text_context:",
        final_context_text,
        "image_context:",
    ]
    final_prompt.extend(context_images)

    # Stream the generation and accumulate the chunks into one answer string.
    response = gemini_pro_vision.generate_content(
        final_prompt,
        stream=True,
    )
    answer = "".join(chunk.text for chunk in response)

    # NOTE(review): assumes the results dict is keyed by rank starting at 0 —
    # confirm against kutils.get_relevant_images_from_query.
    return answer, matching_results_images[0]["image_object"]
|
|
|
|
|
# Gradio UI: a single text input mapped to answer_query, which returns the
# generated answer text plus the best-matching retrieved image.
demo = gr.Interface(
    fn=answer_query,
    inputs="textbox",
    outputs=["textbox", "image"]
)

# Launch the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|
|