# ahu17_pub / app.py
# Uploaded by anyuanay in commit f13eeb8 ("upload 3 main files", verified).
import pandas as pd
import os, sys
import ast
import gradio as gr
import google.generativeai as genai
# Configure the Gemini SDK from the environment. Indexing os.environ directly
# means a missing GOOGLE_API_KEY aborts startup with a KeyError.
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
genai.configure(api_key=GOOGLE_API_KEY)

# Two model handles, created in this order: a text model and a vision model.
# Only gemini_pro_vision is used in this file's visible code.
# NOTE(review): these model names may have been retired by the Gemini API;
# confirm they are still served before deploying.
gemini_pro, gemini_pro_vision = (
    genai.GenerativeModel(model_name=model_id)
    for model_id in ("models/gemini-pro", "models/gemini-pro-vision")
)
import knowledge_triples_utils as kutils
def _load_embedding_table(csv_path, embedding_column):
    """Read *csv_path* and decode the literal-encoded *embedding_column*.

    Each non-missing cell is passed through ``ast.literal_eval``. The cells
    presumably hold list literals such as "[0.1, 0.2]" written by an earlier
    ``to_csv`` (verify against the producer of these files). Missing (NaN)
    cells are kept unchanged.
    """
    table = pd.read_csv(csv_path)
    table[embedding_column] = table[embedding_column].apply(
        lambda cell: ast.literal_eval(cell) if pd.notna(cell) else cell
    )
    return table


# Text-node embeddings used for retrieving relevant text chunks.
all_nodes_csv_path = "AHU_17_All_Nodes_embedding.csv"
all_nodes_df = _load_embedding_table(all_nodes_csv_path, 'node_embedding')

# Image description/title embeddings used for retrieving relevant images.
all_images_csv_path = "AHU_17_All_Images_embeddings_hf.csv"
all_images_df = _load_embedding_table(all_images_csv_path, 'desc_title_embedding')
# Instructions inserted into every Gemini prompt. The literal is kept at module
# level (column 0) so that no function indentation leaks into the text the
# model receives.
_ANSWER_INSTRUCTIONS = '''
You will answer the query based on the text context given in "text_context" and Image context given
in "image_context" along with its Caption:\n
Base your response on "text_context" and "image_context". Do not use any numbers or percentages that are
not present in the "image_context".
Context:
'''


def _build_prompt(query, matching_results_text, matching_results_images):
    """Assemble the multimodal prompt list passed to ``generate_content``.

    Args:
        query: The user's question.
        matching_results_text: Mapping whose values each carry a
            ``"node_text"`` entry.
        matching_results_images: Mapping whose values each carry
            ``"image_object"`` and ``"image_description"`` entries.

    Returns:
        A list of strings and image objects, in this order:
        - the query line;
        - the instructions;
        - the newline-joined text chunks, labelled ``text_context:``;
        - one ``"Image: ", <image>, "Caption: ", <description>`` group per
          matched image, after an ``image_context:`` label.
    """
    context_text = "\n".join(
        value["node_text"] for value in matching_results_text.values()
    )
    prompt = [
        "QUERY: " + query + " ANSWER: ",
        _ANSWER_INSTRUCTIONS,
        "text_context:",
        context_text,
        "image_context:",
    ]
    for value in matching_results_images.values():
        prompt.extend(
            ["Image: ", value["image_object"], "Caption: ", value["image_description"]]
        )
    return prompt


def answer_query(query):
    """Answer *query* with Gemini, grounded on retrieved text chunks and images.

    The function retrieves up to ``top_n=3`` text nodes (matched against
    ``node_embedding``) and up to ``top_n=3`` images (matched against
    ``desc_title_embedding``) using the ``kutils`` helpers. It then builds a
    multimodal prompt and streams the answer from ``gemini_pro_vision``.

    Args:
        query: The user's question, presumably the string from the Gradio
            textbox.

    Returns:
        A tuple ``(answer_text, top_image)``. ``top_image`` is the
        ``image_object`` stored under key 0 of the image matches. It is
        ``None`` when the query is blank or when no image entry exists under
        key 0, so the Gradio image output is simply cleared.
    """
    # A blank query has nothing to retrieve against, so skip the embedding
    # and generation calls entirely.
    if not query or not query.strip():
        return "Please enter a question.", None

    # Matching user text query with "node_embedding" to find relevant chunks.
    matching_results_text = kutils.get_similar_text_from_query(
        query,
        all_nodes_df,
        column_name="node_embedding",
        top_n=3,
        print_citation=False,
    )
    # Matching user text query with "desc_title_embedding" to find relevant images.
    matching_results_images = kutils.get_relevant_images_from_query(
        query,
        all_images_df,
        column_name="desc_title_embedding",
        top_n=3,
    )

    prompt = _build_prompt(query, matching_results_text, matching_results_images)
    stream = gemini_pro_vision.generate_content(prompt, stream=True)
    # Streaming is consumed in full; the UI receives only the joined text.
    answer = "".join(chunk.text for chunk in stream)

    # Key 0 holds the image returned to the UI. Use .get so that an empty
    # match set yields None instead of a KeyError.
    top_match = matching_results_images.get(0)
    top_image = top_match["image_object"] if top_match is not None else None
    return answer, top_image
# Gradio front end: one text box for the question in; the generated answer
# text and the top-matching image out (the two values answer_query returns).
demo = gr.Interface(fn=answer_query, inputs="textbox", outputs=["textbox", "image"])

if __name__ == "__main__":
    # Serve the UI only when executed as a script, not when imported.
    demo.launch()