# Spaces: Runtime error  (status banner captured together with the source; not part of the program)
import argparse
import re
import time

import gradio as gr
import openai
# System message sent with every request.
SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by user."""  # noqa: E501

# User message template. ``{input_txt}`` receives one pre-processed dialogue
# chunk and ``{pos_tags}`` the tags typed into the UI.
USER_FORMAT = """Interview Dialogues:
{input_txt}
Please select the rounds containing one of following tags: {pos_tags}.
Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason"."""  # noqa: E501
def preprocess(input_txt, max_length=4000, max_convs=4):
    """Split an interview transcript into chunks small enough for one request.

    A speaker header is any match of ``说话人<digits> <dd>:<dd>``. The text that
    follows a header, up to the next header, is that speaker's turn. Text
    before the first header is ignored.

    Args:
        input_txt: Raw transcript text.
        max_length: Maximum number of characters per chunk. A turn longer
            than this is cut into several pieces, each repeating the header.
        max_convs: Maximum number of turns packed into one chunk.

    Returns:
        A list of non-empty chunk strings. Inside a chunk every turn is on
        its own line, formatted as ``"<header>: <text>"``. The list is empty
        when the transcript contains no speaker header.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')

    # With one capturing group, re.split yields
    # [preamble, header, text, header, text, ...]: headers sit at the odd
    # indices and each is followed by its text. Pairing by position avoids
    # mistaking content that merely starts with "说话人" for a header.
    pieces = [
        piece.strip().replace('\n', ' ')
        for piece in speaker_pattern.split(input_txt)
    ]

    conversations = []
    for header, text in zip(pieces[1::2], pieces[2::2]):
        # Characters of text that fit next to the header. Clamped to 1 so
        # the loop always makes progress, even with a tiny max_length.
        room = max(max_length - len(header), 1)
        while len(text) > room:
            conversations.append((header, text[:room]))
            text = text[room:]
        conversations.append((header, text))

    chunks = []
    current = ''
    turns_in_current = 0
    for header, text in conversations:
        conv = header + text
        too_long = len(current) + len(conv) >= max_length
        too_many = turns_in_current >= max_convs
        # Only close a chunk that already holds something; otherwise a turn
        # of exactly max_length characters would leave an empty chunk behind.
        if current and (too_long or too_many):
            chunks.append(current)
            current = ''
            turns_in_current = 0
        current += conv
        turns_in_current += 1
    if current:
        chunks.append(current)

    # Put each turn on its own line as "<header>: <text>".
    return [speaker_pattern.sub(r'\n\1: ', chunk).strip() for chunk in chunks]
def chatgpt(messages, temperature=0.0, max_retries=5, retry_delay=1.0):
    """Send a chat request to ``gpt-3.5-turbo`` and return the reply text.

    A failed request is printed and retried with exponential back-off. The
    number of attempts is bounded, so a persistent failure (bad key, no
    network) surfaces as an error instead of recursing without limit.

    Args:
        messages: List of ``{'role': ..., 'content': ...}`` dicts.
        temperature: Sampling temperature forwarded to the API.
        max_retries: Total number of attempts before giving up.
        retry_delay: Seconds to wait after the first failure; doubled after
            each further failure.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        RuntimeError: If every attempt failed. The last underlying error is
            attached as ``__cause__``.
    """
    last_err = None
    for attempt in range(max_retries):
        try:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=temperature
            )
        except Exception as err:  # the API client raises many error types
            print(err)
            last_err = err
            # No point sleeping after the final attempt.
            if attempt < max_retries - 1:
                time.sleep(retry_delay * 2 ** attempt)
        else:
            return completion.choices[0].message.content
    raise RuntimeError(
        f'chat completion failed after {max_retries} attempts') from last_err
def llm(pos_tags, neg_tags, input_txt):
    """Ask the model which rounds of one dialogue chunk match the tags.

    Builds the system/user message pair from ``SYSTEM_PROMPT`` and
    ``USER_FORMAT``, sends it through ``chatgpt`` and prints both the user
    prompt and the response.

    Args:
        pos_tags: Tags to select, inserted verbatim into the prompt.
        neg_tags: Accepted for interface symmetry with the UI; not used.
        input_txt: One pre-processed dialogue chunk.

    Returns:
        The model's response as returned by ``chatgpt``.
    """
    user = USER_FORMAT.format(input_txt=input_txt, pos_tags=pos_tags)
    messages = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user},
    ]
    response = chatgpt(messages)
    print(f'USER:\n\n{user}')
    print(f'RESPONSE:\n\n{response}')
    return response
def postprocess(input_txt, output_txt_list):
    """Merge per-chunk model responses into one validated result listing.

    Only responses containing at least one speaker header
    (``说话人<digits> <dd>:<dd>``) are kept. The merged text is split back
    into (header, text) items, and an item is kept only if its header also
    occurs in the original transcript.

    Args:
        input_txt: The original transcript, used as the set of valid headers.
        output_txt_list: One model response per chunk. Empty or ``None``
            entries are skipped.

    Returns:
        The surviving items, stripped and joined with newlines. An empty
        string when nothing survives.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    known_speakers = set(speaker_pattern.findall(input_txt))

    merged = ''.join(
        txt for txt in output_txt_list
        if txt and speaker_pattern.search(txt)
    )

    # re.split with one capturing group yields
    # [preamble, header, text, header, text, ...]. Pairing by position keeps
    # each header with its own text, even if that text starts with "说话人".
    pieces = speaker_pattern.split(merged)
    results = [
        (header + text).strip()
        for header, text in zip(pieces[1::2], pieces[2::2])
        if header in known_speakers
    ]
    return '\n'.join(results)
def filter(pos_tags, neg_tags, input_txt):
    """Run the whole pipeline: chunk the transcript, query the model, merge.

    NOTE: the name shadows the built-in ``filter``. It is kept because the
    UI wiring refers to this function by name.

    Args:
        pos_tags: Tags to select, forwarded to ``llm``.
        neg_tags: Forwarded to ``llm``.
        input_txt: Raw interview transcript.

    Returns:
        The merged, validated listing produced by ``postprocess``.
    """
    responses = [
        llm(pos_tags, neg_tags, chunk) for chunk in preprocess(input_txt)
    ]
    return postprocess(input_txt, responses)
if __name__ == '__main__':
    # No options are defined; parsing still provides --help and rejects
    # unexpected command-line arguments.
    parser = argparse.ArgumentParser()
    parser.parse_args()

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: tag inputs, transcript input and action buttons.
            with gr.Column(scale=0.3):
                with gr.Row():
                    pos_txt = gr.Textbox(
                        lines=2,
                        label='Positive Tags',
                        elem_id='pos_textbox',
                        placeholder='Enter positive tags split by semicolon')
                with gr.Row():
                    # Hidden: negative tags are passed along but not used
                    # when building the prompt.
                    neg_txt = gr.Textbox(
                        lines=2,
                        visible=False,
                        label='Negative Tags',
                        elem_id='neg_textbox',
                        placeholder='Enter negative tags split by semicolon')
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=5,
                        label='Input',
                        elem_id='input_textbox',
                        placeholder='Enter text and press submit')
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
            # Right column: the filtered result.
            with gr.Column(scale=0.7):
                output_txt = gr.Textbox(
                    label='Output',
                    elem_id='output_textbox')
                output_txt = output_txt.style(height=690)

        submit.click(
            filter,
            [pos_txt, neg_txt, input_txt],
            [output_txt])
        # The three boxes to reset must be passed together as the `outputs`
        # list so the three returned values map onto them.
        clear.click(
            lambda: ['', '', ''],
            None,
            [pos_txt, neg_txt, input_txt])

    demo.queue(concurrency_count=6)
    demo.launch()