Spaces:
Runtime error
Runtime error
File size: 5,684 Bytes
65cfc9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import argparse
import re
import time

import gradio as gr
import openai
# System prompt: frames the model as a filter over interview dialogue rounds.
SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by user.""" # noqa: E501
# User-message template, filled by llm() via str.format with one chunk of
# dialogue text ({input_txt}) and the user-entered positive tags ({pos_tags}).
# The requested "speaker_name speaking_time: tag, reason" line format is what
# postprocess() later splits apart by speaker header.
# NOTE(review): "you should ONLY outputs" is a grammar slip in the prompt; it is
# left unchanged because editing it changes the text the model receives.
USER_FORMAT = """Interview Dialogues:
{input_txt}
Please select the rounds containing one of following tags: {pos_tags}.
Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason".""" # noqa: E501
def preprocess(input_txt, max_length=4000, max_convs=4):
    """Split a transcript into LLM-sized chunks of speaker rounds.

    The transcript is split on speaker headers of the form
    ``说话人<n> <dd>:<dd>``. Each header is paired with the text that follows
    it, and newlines inside that text are replaced by spaces. Text before the
    first header is dropped. A round longer than ``max_length`` is cut into
    several pieces, each repeating the header. Pieces are then packed into
    chunks that stay under ``max_length`` characters and hold at most
    ``max_convs`` headers. In each chunk every header is rewritten as
    ``"\\n<header>: "``.

    Args:
        input_txt: Raw transcript text.
        max_length: Soft character budget per round piece and per chunk.
        max_convs: Maximum number of speaker headers per chunk.

    Returns:
        A list of non-empty chunk strings. The list is empty when the
        transcript contains no speaker headers.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    # re.split with a capturing group keeps the headers, so the result
    # alternates: [preamble, header, body, header, body, ...].
    segments = [seg.strip().replace('\n', ' ')
                for seg in speaker_pattern.split(input_txt)]

    conversations = []
    for idx, header in enumerate(segments):
        if not header.startswith('说话人') or idx >= len(segments) - 1:
            continue
        body = segments[idx + 1]
        if body.startswith('说话人'):
            body = ''
        # Characters left for the body once the header is counted. Clamp to 1
        # so the loop always makes progress, even if the header alone exceeds
        # max_length (previously this case looped forever).
        budget = max(max_length - len(header), 1)
        while len(body) > budget:
            conversations.append([header, body[:budget]])
            body = body[budget:]
        conversations.append([header, body])

    chunks = ['']
    for conv in conversations:
        piece = ''.join(conv)
        current = chunks[-1]
        # Only open a new chunk when the current one already holds text.
        # Otherwise an oversize first piece would leave an empty chunk behind,
        # which would cost an LLM call on empty input.
        if current and (len(current) + len(piece) >= max_length
                        or len(speaker_pattern.findall(current)) >= max_convs):
            chunks.append('')
        chunks[-1] += piece

    processed = [speaker_pattern.sub(r'\n\1: ', chunk).strip()
                 for chunk in chunks]
    return [chunk for chunk in processed if chunk]
def chatgpt(messages, temperature=0.0, max_retries=5, retry_delay=1.0):
    """Send ``messages`` to gpt-3.5-turbo and return the reply text.

    Any exception from the request is printed and the request is retried
    with exponential backoff (``retry_delay * 2**attempt`` seconds). The
    previous version retried by unbounded recursion with no delay. It
    hammered the API on persistent failures such as a bad key, until a
    RecursionError ended it.

    Args:
        messages: Chat messages in the ``{'role': ..., 'content': ...}`` form.
        temperature: Sampling temperature passed to the API.
        max_retries: Number of retries after the first attempt.
        retry_delay: Base delay in seconds before the first retry.

    Returns:
        The content of the first choice's message.

    Raises:
        The last exception from the API once all retries are exhausted.
    """
    attempts = max(max_retries, 0) + 1
    for attempt in range(attempts):
        try:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=temperature,
            )
            return completion.choices[0].message.content
        except Exception as err:
            print(err)
            if attempt == attempts - 1:
                raise
            time.sleep(retry_delay * 2 ** attempt)
def llm(pos_tags, neg_tags, input_txt):
    """Ask the chat model which rounds of ``input_txt`` match ``pos_tags``.

    ``neg_tags`` is accepted to match the UI's inputs but is not used in
    the prompt. The filled-in user prompt and the model's reply are printed
    for debugging, and the raw reply text is returned.
    """
    prompt = USER_FORMAT.format(input_txt=input_txt, pos_tags=pos_tags)
    reply = chatgpt([
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': prompt},
    ])
    print(f'USER:\n\n{prompt}')
    print(f'RESPONSE:\n\n{reply}')
    return reply
def postprocess(input_txt, output_txt_list):
    """Collect valid ``"<speaker header>: ..."`` items from LLM responses.

    Each response is split on speaker headers (``说话人<n> <dd>:<dd>``). Each
    header that also occurs in ``input_txt`` is kept together with the text
    that follows it. Headers the model invented are dropped, and so is any
    text before a response's first header.

    Responses are parsed one at a time. Concatenating them first (as the
    previous version did) glued the preamble of one response, e.g.
    "Selected rounds:", onto the last item of the previous response.

    Args:
        input_txt: The original transcript, used to validate speaker headers.
        output_txt_list: Raw reply strings, one per chunk.

    Returns:
        The kept items, stripped and joined with newlines.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    speakers = set(speaker_pattern.findall(input_txt))
    results = []
    for response in output_txt_list:
        # Capturing split alternates [preamble, header, text, header, ...].
        segments = speaker_pattern.split(response or '')
        for idx, seg in enumerate(segments):
            if not seg.startswith('说话人') or seg not in speakers:
                continue
            item = seg
            if idx < len(segments) - 1 and not segments[idx + 1].startswith('说话人'):
                item += segments[idx + 1]
            results.append(item.strip())
    return '\n'.join(results)
def filter(pos_tags, neg_tags, input_txt):
    """Run the full pipeline: chunk the transcript, query the LLM, merge.

    Each chunk from ``preprocess`` is sent to ``llm`` in order, and the
    replies are merged against the original transcript by ``postprocess``.
    The name shadows the ``filter`` builtin, but the UI wiring refers to it.
    """
    chunks = preprocess(input_txt)
    replies = [llm(pos_tags, neg_tags, chunk) for chunk in chunks]
    return postprocess(input_txt, replies)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    with gr.Blocks() as demo:
        with gr.Row():
            # Integer scales keep the original 3:7 width ratio; Column.scale is
            # documented as an integer weight.
            with gr.Column(scale=3):
                with gr.Row():
                    pos_txt = gr.Textbox(
                        lines=2,
                        label='Positive Tags',
                        elem_id='pos_textbox',
                        placeholder='Enter positive tags split by semicolon')
                with gr.Row():
                    # Hidden: llm() does not use negative tags in its prompt.
                    neg_txt = gr.Textbox(
                        lines=2,
                        visible=False,
                        label='Negative Tags',
                        elem_id='neg_textbox',
                        placeholder='Enter negative tags split by semicolon')
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=5,
                        label='Input',
                        elem_id='input_textbox',
                        placeholder='Enter text and press submit')
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
            with gr.Column(scale=7):
                output_txt = gr.Textbox(
                    label='Output',
                    elem_id='output_textbox')
                output_txt = output_txt.style(height=690)
        submit.click(
            filter,
            [pos_txt, neg_txt, input_txt],
            [output_txt])
        # Outputs must be one list. Passing the components as extra positional
        # arguments bound only pos_txt as output (for three return values) and
        # sent the others into unrelated parameters of click().
        clear.click(
            lambda: ['', '', ''],
            None,
            [pos_txt, neg_txt, input_txt])
    demo.queue(concurrency_count=6)
    demo.launch()
|