GPT4News / app.py
stevengrove
add requirements
0e5afc4
raw
history blame
5.68 kB
import argparse
import re
import time

import gradio as gr
import openai
# System-role message sent with every chat request (see `llm`).
SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by user.""" # noqa: E501
# User-role message template. `llm` fills `{input_txt}` with one dialogue
# chunk and `{pos_tags}` with the tags whose rounds should be selected.
USER_FORMAT = """Interview Dialogues:
{input_txt}
Please select the rounds containing one of following tags: {pos_tags}.
Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason".""" # noqa: E501
def preprocess(input_txt, max_length=4000, max_convs=4):
    """Split an interview transcript into chunks small enough for one prompt.

    The transcript is cut at speaker tags of the form ``说话人<N> MM:SS``. A tag
    plus the text that follows it is one round. Text that precedes the first
    tag is discarded. Rounds are then packed, in order, into chunks.

    Args:
        input_txt: Raw transcript text.
        max_length: Character budget. A round longer than this (tag + text) is
            split into several rounds that repeat the tag, and a new chunk is
            started when adding a round would reach this many characters.
        max_convs: A new chunk is started once the current one already holds
            this many speaker tags.

    Returns:
        A list of non-empty chunk strings. Inside a chunk every round sits on
        its own line as ``<tag>: <text>``. The list is empty when the input
        contains no round at all.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    # With a capturing group, split() keeps the tags: text, tag, text, tag, ...
    pieces = [piece.strip().replace('\n', ' ')
              for piece in speaker_pattern.split(input_txt)]

    conversations = []
    for idx, txt in enumerate(pieces):
        if not txt.startswith('说话人'):
            continue
        body = pieces[idx + 1] if idx + 1 < len(pieces) else ''
        if body.startswith('说话人'):
            # The next piece is treated as a header on its own iteration, so
            # it must not also be consumed as this round's text.
            body = ''
        # Characters of text that fit next to the tag. Clamped to at least 1
        # so the loop below always makes progress, even for a tiny max_length.
        room = max(1, max_length - len(txt))
        while len(body) > room:
            conversations.append([txt, body[:room]])
            body = body[room:]
        conversations.append([txt, body])

    chunks = ['']
    for conv in conversations:
        conv_txt = ''.join(conv)
        current = chunks[-1]
        # Only roll over from a chunk that already has content; otherwise a
        # round of exactly max_length characters would leave an empty chunk.
        if current and (
                len(current) + len(conv_txt) >= max_length
                or len(speaker_pattern.findall(current)) >= max_convs):
            chunks.append('')
        chunks[-1] += conv_txt

    processed_txt_list = []
    for chunk in chunks:
        rendered = speaker_pattern.sub(r'\n\1: ', chunk).strip()
        if rendered:
            # Skip empty chunks so no request is made for an empty dialogue.
            processed_txt_list.append(rendered)
    return processed_txt_list
def chatgpt(messages, temperature=0.0, max_retries=5, retry_delay=1.0):
    """Send a chat conversation to ``gpt-3.5-turbo`` and return the reply text.

    A failed request is printed and retried with exponential backoff instead
    of recursing without limit, so a permanent failure surfaces as an
    exception rather than looping until the recursion limit is hit.

    Args:
        messages: List of ``{'role': ..., 'content': ...}`` dicts.
        temperature: Sampling temperature forwarded to the API.
        max_retries: Number of retries after the first failed attempt.
        retry_delay: Seconds to wait before the first retry; the wait doubles
            after every further failure.

    Returns:
        The message content of the first returned choice.

    Raises:
        Exception: The error from the last attempt once all retries are used.
    """
    attempt = 0
    while True:
        try:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=temperature
            )
            return completion.choices[0].message.content
        except Exception as err:
            print(err)
            if attempt >= max_retries:
                raise
            time.sleep(retry_delay * 2 ** attempt)
            attempt += 1
def llm(pos_tags, neg_tags, input_txt):
    """Ask the chat model which rounds of one dialogue chunk match the tags.

    Args:
        pos_tags: Tags to select, inserted into the prompt as given.
        neg_tags: Accepted for interface symmetry; not used in the prompt.
        input_txt: One dialogue chunk.

    Returns:
        The model's reply text. Prompt and reply are also printed to stdout.
    """
    prompt = USER_FORMAT.format(pos_tags=pos_tags, input_txt=input_txt)
    system_message = dict(role='system', content=SYSTEM_PROMPT)
    user_message = dict(role='user', content=prompt)
    reply = chatgpt([system_message, user_message])
    print(f'USER:\n\n{prompt}')
    print(f'RESPONSE:\n\n{reply}')
    return reply
def postprocess(input_txt, output_txt_list):
    """Merge the model replies into one cleaned list of selected rounds.

    Each reply is parsed on its own, so text that a reply prints before its
    first speaker tag is dropped rather than being glued onto the last entry
    of the previous reply. Entries whose speaker tag does not occur in the
    original transcript are discarded.

    Args:
        input_txt: The original transcript, used as the set of valid tags.
        output_txt_list: Model replies, one per chunk. Empty or ``None``
            replies are skipped.

    Returns:
        The kept entries, one per line, each starting with its speaker tag.
        An empty string when nothing was kept.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    speakers = set(speaker_pattern.findall(input_txt))

    results = []
    for response in output_txt_list:
        if not response:
            continue
        # split() with a capturing group yields text, tag, text, tag, ..., text
        # so odd positions are tags and each is followed by its detail text.
        pieces = speaker_pattern.split(response)
        for tag, detail in zip(pieces[1::2], pieces[2::2]):
            if tag not in speakers:
                continue
            if detail.startswith('说话人'):
                # Keep the original rule: such a detail is not attached.
                detail = ''
            results.append((tag + detail).strip())
    return '\n'.join(results)
def filter(pos_tags, neg_tags, input_txt):
    """Run the whole pipeline: chunk the transcript, query the model, merge.

    The name shadows the ``filter`` builtin inside this module; it is kept
    because the UI wiring refers to it by this name.

    Args:
        pos_tags: Tags to select.
        neg_tags: Passed through to ``llm``.
        input_txt: Full transcript text.

    Returns:
        The merged, cleaned selection text produced by ``postprocess``.
    """
    replies = [llm(pos_tags, neg_tags, chunk)
               for chunk in preprocess(input_txt)]
    return postprocess(input_txt, replies)
if __name__ == '__main__':
    # No options are defined; parsing still provides --help and rejects
    # unexpected command-line arguments.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: tag inputs, transcript input and the two buttons.
            with gr.Column(scale=0.3):
                with gr.Row():
                    pos_txt = gr.Textbox(
                        lines=2,
                        label='Positive Tags',  # was misspelled 'Postive Tags'
                        elem_id='pos_textbox',
                        placeholder='Enter positive tags split by semicolon')
                with gr.Row():
                    # Hidden in the UI; its value is still passed to `filter`.
                    neg_txt = gr.Textbox(
                        lines=2,
                        visible=False,
                        label='Negative Tags',
                        elem_id='neg_textbox',
                        placeholder='Enter negative tags split by semicolon')
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=5,
                        label='Input',
                        elem_id='input_textbox',
                        placeholder='Enter text and press submit')
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
            # Right column: the filtered result.
            with gr.Column(scale=0.7):
                output_txt = gr.Textbox(
                    label='Output',
                    elem_id='output_textbox')
                output_txt = output_txt.style(height=690)

        submit.click(
            filter,
            [pos_txt, neg_txt, input_txt],
            [output_txt])
        # The outputs must be ONE list. Passed as separate positional
        # arguments, only pos_txt was an output and the other two components
        # landed in unrelated parameters, while the callback returns three
        # values.
        clear.click(
            lambda: ['', '', ''],
            None,
            [pos_txt, neg_txt, input_txt])

    demo.queue(concurrency_count=6)
    demo.launch()