import os
import argparse
import gradio as gr
from difflib import Differ
from functools import partial
from string import Template
from utils import load_prompt, setup_gemini_client
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ai-studio-api-key", type=str, default=os.getenv("GEMINI_API_KEY"))
parser.add_argument("--vertexai", action="store_true", default=False)
parser.add_argument("--vertexai-project", type=str, default="gcp-ml-172005")
parser.add_argument("--vertexai-location", type=str, default="us-central1")
parser.add_argument("--model", type=str, default="gemini-1.5-flash")
parser.add_argument("--prompt-tmpl-path", type=str, default="configs/prompts.toml")
parser.add_argument("--css-path", type=str, default="statics/styles.css")
args = parser.parse_args()
return args
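
# Return the stored upload record matching `filename`, or None if the file
# has not been uploaded in this session yet.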
def find_attached_file(filename, attached_files):
for file in attached_files:
if file['name'] == filename:
return file
return None
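
# Chat handler for gr.ChatInterface: uploads any attached file once, sends
# the running conversation to Gemini, then adaptively revises the summary.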
def echo(message, history, state):
attached_file = None
if message['files']:
path_local = message['files'][0]
filename = os.path.basename(path_local)
        attached_file = find_attached_file(filename, state["attached_files"])
        if attached_file is None:
            path_gcp = client.files.upload(path=path_local)
            state["attached_files"].append({
                "name": filename,
                "path_local": path_local,
                "gcp_entity": path_gcp,
                "path_gcp": path_gcp.name,
                "mime_type": path_gcp.mime_type,
                "expiration_time": path_gcp.expiration_time,
            })
            attached_file = path_gcp
        else:
            # Reuse the already-uploaded entity rather than re-uploading
            # (and rather than passing the bookkeeping dict to the model).
            attached_file = attached_file["gcp_entity"]
    # `history` from gr.ChatInterface arrives as role/content dicts, e.g.
    # {'role': 'user', 'content': ...}; we instead track the raw contents
    # sent to Gemini in `state['messages']`.
    user_message = [message['text']]
    if attached_file:
        user_message.append(attached_file)

    state['messages'] = state['messages'] + user_message
    response = client.models.generate_content(
        model=model_name,
        contents=state['messages']
    )
    model_response = response.text
    # Record the assistant turn so subsequent calls see the full conversation.
    state['messages'].append(model_response)
    # Update the adaptive summary: ask the model to revise the previous
    # summary (empty on the first turn) in light of the latest exchange.
    response = client.models.generate_content(
        model=model_name,
        contents=[
            Template(
                prompt_tmpl['summarization']['prompt']
            ).safe_substitute(
                previous_summary=state['summary'],
                latest_conversation=str({"user": message['text'], "assistant": model_response})
            )
        ]
    )
if state['summary'] != "":
prev_summary = state['summary_history'][-1]
else:
prev_summary = ""
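
    # Character-by-character diff between the previous and new summary;
    # Differ.compare over two strings yields per-character tokens, which
    # the HighlightedText component merges via combine_adjacent=True.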
d = Differ()
state['summary'] = response.text
state['summary_history'].append(response.text)
state['summary_diff_history'].append(
[
(token[2:], token[0] if token[0] != " " else None)
for token in d.compare(prev_summary, state['summary'])
]
)
return (
model_response,
state,
state['summary_diff_history'][-1],
state['summary_history'][-1],
        gr.Slider(
            maximum=len(state['summary_history']),
            value=len(state['summary_history']),
            visible=len(state['summary_history']) > 1,
            interactive=True
        ),
)
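
# Toggle between the highlighted diff view and the plain Markdown view.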
def change_view_toggle(view_toggle):
if view_toggle == "Diff":
return (
gr.HighlightedText(visible=True),
gr.Markdown(visible=False)
)
else:
return (
gr.HighlightedText(visible=False),
gr.Markdown(visible=True)
)
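
# Show the n-th summary and its diff (the slider is 1-indexed).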
def navigate_to_summary(summary_num, state):
return (
state['summary_diff_history'][summary_num-1],
state['summary_history'][summary_num-1]
)
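
# Build the Gradio app. `client`, `prompt_tmpl`, and `model_name` are module
# globals because `echo` is invoked by gr.ChatInterface without access to `args`.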
def main(args):
    with open(args.css_path, "r") as f:
        style_css = f.read()

    global client, prompt_tmpl, model_name
    client = setup_gemini_client(args)
    prompt_tmpl = load_prompt(args)
    model_name = args.model
## Gradio Blocks
with gr.Blocks(css=style_css) as demo:
# State per session
state = gr.State({
"messages": [],
"attached_files": [],
"summary": "",
"summary_history": [],
"summary_diff_history": []
})
gr.Markdown("# Adaptive Summarization")
gr.Markdown("AdaptSum stands for Adaptive Summarization. This project focuses on developing an LLM-powered system for dynamic summarization. Instead of generating entirely new summaries with each update, the system intelligently identifies and modifies only the necessary parts of the existing summary. This approach aims to create a more efficient and fluid summarization process within a continuous chat interaction with an LLM.")
with gr.Accordion("Adaptive Summary"):
with gr.Row(elem_id="view-toggle-btn-container"):
view_toggle_btn = gr.Radio(
choices=["Diff", "Markdown"],
value="Markdown",
interactive=True,
elem_id="view-toggle-btn"
)
summary_diff = gr.HighlightedText(
label="Summary so far",
combine_adjacent=True,
show_legend=True,
                color_map={"+": "green", "-": "red"},  # green = additions, red = deletions
elem_classes=["summary-window"],
visible=False
)
summary_md = gr.Markdown(
label="Summary so far",
elem_classes=["summary-window"],
visible=True
)
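            # Slider for stepping back through summary history; hidden until
            # more than one summary exists (see the gr.Slider returned by echo).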
summary_num = gr.Slider(label="summary history", minimum=1, maximum=1, step=1, show_reset_button=False, visible=False)
view_toggle_btn.change(change_view_toggle, inputs=[view_toggle_btn], outputs=[summary_diff, summary_md])
summary_num.release(navigate_to_summary, inputs=[summary_num, state], outputs=[summary_diff, summary_md])
        with gr.Column(elem_id="chat-window"):
gr.ChatInterface(
multimodal=True,
type="messages",
fn=echo,
additional_inputs=[state],
additional_outputs=[state, summary_diff, summary_md, summary_num],
)
return demo
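
# Example launches (see parse_args for all flags), assuming this file is app.py:
#   GEMINI_API_KEY=... python app.py
#   python app.py --vertexai --vertexai-project <project-id>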
if __name__ == "__main__":
args = parse_args()
demo = main(args)
demo.launch()