# paper_qa / gen /gemini.py
# chansung's picture
# Update gen/gemini.py
# 6e0c0c0 verified
import ast
import copy
import toml
from string import Template
from pathlib import Path
from flatdict import FlatDict
import google.generativeai as genai
from gen.utils import parse_first_json_snippet
def determine_model_name(given_image=None):
    """Pick the Gemini model name for a request.

    Args:
        given_image: Image payload for the request, or ``None`` for text only.

    Returns:
        ``"gemini-1.0-pro"`` when no image is supplied, otherwise
        ``"gemini-1.0-pro-vision"``.
    """
    return "gemini-1.0-pro" if given_image is None else "gemini-1.0-pro-vision"
def construct_image_part(given_image):
    """Wrap image data in the part dict appended to the Gemini prompt.

    Args:
        given_image: Image payload; stored unchanged under ``"data"``.

    Returns:
        A dict with ``"mime_type"`` fixed to ``"image/jpeg"`` and ``"data"``
        set to ``given_image``.
    """
    # The MIME type is always JPEG; the payload is not inspected here.
    return dict(mime_type="image/jpeg", data=given_image)
def call_gemini(prompt="", API_KEY=None, given_text=None, given_image=None, generation_config=None, safety_settings=None):
    """Send a prompt, optionally with context text and an image, to Gemini.

    Args:
        prompt: Instruction text for the model.
        API_KEY: API key handed to ``genai.configure``.
        given_text: Optional context placed below the prompt, separated by a
            dashed line.
        given_image: Optional image payload. When present, the vision model is
            selected and the payload is attached as an ``image/jpeg`` part.
        generation_config: Sampling settings; the defaults below are used when
            ``None``.
        safety_settings: Safety thresholds; when ``None`` all four harm
            categories are set to ``BLOCK_NONE``.

    Returns:
        The ``text`` attribute of the response from ``generate_content``.
    """
    genai.configure(api_key=API_KEY)

    if generation_config is None:
        generation_config = {
            "temperature": 0.8,
            "top_p": 1,
            "top_k": 32,
            "max_output_tokens": 4096,
        }

    if safety_settings is None:
        safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in (
                "HARM_CATEGORY_HARASSMENT",
                "HARM_CATEGORY_HATE_SPEECH",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "HARM_CATEGORY_DANGEROUS_CONTENT",
            )
        ]

    model = genai.GenerativeModel(
        model_name=determine_model_name(given_image),
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    # Build the prompt once. Appending the formatted block to a string that
    # already held `prompt` sent the instructions twice whenever context text
    # was supplied.
    user_prompt = prompt
    if given_text is not None:
        user_prompt = (
            f"{prompt}\n"
            "------------------------------------------------\n"
            f"{given_text}\n"
        )

    prompt_parts = [user_prompt]
    if given_image is not None:
        prompt_parts.append(construct_image_part(given_image))

    response = model.generate_content(prompt_parts)
    return response.text
def try_out(prompt, given_text, gemini_api_key, given_image=None, retry_num=10):
    """Call Gemini and parse the first JSON snippet of its reply, with retries.

    Args:
        prompt: Instruction text forwarded to ``call_gemini``.
        given_text: Context text forwarded to ``call_gemini``.
        gemini_api_key: API key forwarded to ``call_gemini``.
        given_image: Optional image payload forwarded to ``call_gemini``.
        retry_num: Maximum number of attempts.

    Returns:
        The value produced by ``parse_first_json_snippet``, or ``None`` when
        no attempt yielded a parsed result.
    """
    # Every attempt counts toward the limit. Counting only exceptions meant a
    # parser that returned None without raising would loop forever and keep
    # calling the API.
    for _ in range(retry_num):
        try:
            raw_reply = call_gemini(
                prompt=prompt,
                given_text=given_text,
                given_image=given_image,
                API_KEY=gemini_api_key,
            )
            parsed = parse_first_json_snippet(raw_reply)
        except Exception as e:
            # Deliberate best-effort boundary: any API or parsing failure is
            # reported and the request is tried again.
            print(f"......retry {e}")
            continue

        if parsed is not None:
            return parsed

    return None
def get_basic_qa(text, gemini_api_key, trucate=7000):
    """Generate the basic Q&A for a text using the ``basic_qa`` prompt.

    Args:
        text: Source text; only the first ``trucate`` characters are sent.
        gemini_api_key: API key forwarded to ``try_out``.
        trucate: Length of the leading slice of ``text`` to send.

    Returns:
        Whatever ``try_out`` returns for the ``basic_qa`` prompt.
    """
    # The prompt file is resolved relative to the current working directory.
    prompt_file = Path('.') / 'constants' / 'prompts.toml'
    basic_prompt = toml.load(prompt_file)['basic_qa']['prompt']
    excerpt = text[:trucate]
    return try_out(basic_prompt, excerpt, gemini_api_key=gemini_api_key)
def _is_complete_follow_up(response):
    """Return True when ``response`` holds a follow-up question and both answers."""
    # try_out may return None (or some other non-dict), which must count as
    # "incomplete" rather than raise TypeError on the membership tests.
    if not isinstance(response, dict):
        return False
    answers = response.get('answers')
    return (
        'follow up question' in response
        and isinstance(answers, dict)
        and 'eli5' in answers
        and 'expert' in answers
    )


def _request_follow_up(prompt, text, gemini_api_key, max_attempts):
    """Ask for a follow-up until a complete one arrives or attempts run out."""
    for _ in range(max_attempts):
        response = try_out(prompt, text, gemini_api_key=gemini_api_key)
        if _is_complete_follow_up(response):
            return response
    return None


def get_deep_qa(text, basic_qa, gemini_api_key, trucate=7000, max_attempts=5):
    """Extend each basic Q&A entry with in-depth and broad follow-up questions.

    For every entry in ``basic_qa['qna']`` two follow-ups are requested with
    the ``deep_qa`` prompt, one with tone ``"in-depth"`` and one with tone
    ``"broad"``. The entry and its follow-ups are flattened with ``FlatDict``,
    each flattened key is prefixed with ``"<index>_"``, and the result is
    merged into ``basic_qa``.

    Args:
        text: Source text; only the first ``trucate`` characters are sent.
        basic_qa: Dict with ``'title'`` and ``'qna'``; each ``qna`` entry needs
            ``'question'`` and ``'answers'['expert']``. Updated in place.
        gemini_api_key: API key forwarded to ``try_out``.
        trucate: Length of the leading slice of ``text`` to send.
        max_attempts: Upper bound on ``try_out`` rounds per follow-up. When no
            complete follow-up is obtained, that follow-up is left out.

    Returns:
        The same ``basic_qa`` object, with the flattened keys added.
    """
    prompts = toml.load(Path('.') / 'constants' / 'prompts.toml')
    # Build the template once; only the substituted values differ per question.
    deep_template = Template(prompts['deep_qa']['prompt'])
    excerpt = text[:trucate]

    title = basic_qa['title']
    # Work on a copy so the original 'qna' list in basic_qa stays untouched.
    qnas = copy.deepcopy(basic_qa['qna'])

    for idx, qna in enumerate(qnas):
        question = qna['question']
        expert_answer = qna['answers']['expert']

        depth_search_prompt = deep_template.substitute(
            title=title, previous_question=question, previous_answer=expert_answer, tone="in-depth"
        )
        breath_search_prompt = deep_template.substitute(
            title=title, previous_question=question, previous_answer=expert_answer, tone="broad"
        )

        depth_search_response = _request_follow_up(
            depth_search_prompt, excerpt, gemini_api_key, max_attempts
        )
        breath_search_response = _request_follow_up(
            breath_search_prompt, excerpt, gemini_api_key, max_attempts
        )

        if depth_search_response is not None:
            qna['additional_depth_q'] = depth_search_response
        if breath_search_response is not None:
            qna['additional_breath_q'] = breath_search_response

        # Flatten the nested entry, then re-key every flattened key with the
        # question index so entries do not collide once merged into basic_qa.
        qna = FlatDict(qna)
        qna_tmp = copy.deepcopy(qna)  # iterate a snapshot while mutating qna
        for k in qna_tmp:
            value = qna.pop(k)
            qna[f'{idx}_{k}'] = value

        # NOTE(review): the str()/literal_eval round trip turns the FlatDict
        # into a plain dict; it relies on every value having a literal repr.
        basic_qa.update(ast.literal_eval(str(qna)))

    return basic_qa