File size: 5,684 Bytes
65cfc9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import argparse
import re
import time

import openai
import gradio as gr


# System message sent with every request: frames the model as a filter
# over the interview dialogue supplied in the user message.
SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by the user."""  # noqa: E501

# User message template, filled with str.format:
#   {input_txt} - one preprocessed transcript chunk ("<header>: <text>" lines)
#   {pos_tags}  - the tag string typed into the UI
# The requested "speaker_name speaking_time: tag, reason" item format puts
# the round header first, which is what the reply post-processing splits on.
USER_FORMAT = """Interview Dialogues:
{input_txt}

Please select the rounds containing one of the following tags: {pos_tags}.
Note that you should ONLY output a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason"."""  # noqa: E501


def preprocess(input_txt, max_length=4000, max_convs=4):
    """Split a speaker-tagged transcript into prompt-sized chunks.

    A round starts at a header such as ``说话人1 00:05`` ("speaker 1" plus a
    timestamp) and runs until the next header; any text before the first
    header is discarded. Each header and round text is stripped and its
    newlines are replaced by spaces. A round whose header plus text is
    longer than ``max_length`` characters is cut into several rounds that
    repeat the same header. Rounds are then packed, in order, into chunks of
    at most ``max_convs`` rounds whose combined header+text length stays
    below ``max_length`` (a single round may occupy a chunk on its own).

    Args:
        input_txt: Raw transcript text.
        max_length: Character budget for one round and for one chunk.
        max_convs: Maximum number of rounds per chunk.

    Returns:
        A list of chunk strings with one ``"<header>: <text>"`` line per
        round. Empty when ``input_txt`` contains no header, so callers make
        no request for an empty chunk.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    # With a capturing group, re.split returns
    # [preamble, header1, text1, header2, text2, ...]: headers sit at the odd
    # indices and each is followed by its (possibly empty) text. Using the
    # position instead of startswith('说话人') keeps text that merely begins
    # with "说话人" from being mistaken for a header.
    parts = [part.strip().replace('\n', ' ')
             for part in speaker_pattern.split(input_txt)]

    conversations = []
    for header, body in zip(parts[1::2], parts[2::2]):
        # Room left for text once the header is counted; always advance by
        # at least one character so a tiny max_length cannot loop forever.
        room = max(max_length - len(header), 1)
        pieces = [body[i:i + room] for i in range(0, len(body), room)]
        conversations.extend((header, piece) for piece in pieces or [''])

    chunks = []
    current, current_len = [], 0
    for header, body in conversations:
        conv_length = len(header) + len(body)
        # Start a new chunk when this round would overflow the budget or the
        # round limit is reached - but never emit an empty chunk.
        if current and (current_len + conv_length >= max_length
                        or len(current) >= max_convs):
            chunks.append(current)
            current, current_len = [], 0
        current.append((header, body))
        current_len += conv_length
    if current:
        chunks.append(current)

    return ['\n'.join(f'{header}: {body}' for header, body in chunk).strip()
            for chunk in chunks]


def chatgpt(messages, temperature=0.0, model="gpt-3.5-turbo",
            max_retries=5, retry_delay=1.0):
    """Send a chat completion request and return the reply text.

    A failed request is printed and retried up to ``max_retries`` more
    times, sleeping ``retry_delay``, ``2 * retry_delay``, ... seconds
    between attempts, so a persistent failure neither hammers the API nor
    recurses without bound.

    Args:
        messages: Chat messages passed straight to
            ``openai.ChatCompletion.create``.
        temperature: Sampling temperature.
        model: Model name.
        max_retries: Extra attempts after the first failure (non-negative).
        retry_delay: Initial backoff in seconds; doubled after each failure.

    Returns:
        The content of the first choice's message.

    Raises:
        Exception: whatever the last attempt raised, once retries are
            exhausted.
    """
    for attempt in range(max(max_retries, 0) + 1):
        try:
            completion = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=temperature
            )
            return completion.choices[0].message.content
        except Exception as err:
            if attempt >= max_retries:
                raise
            print(err)
            time.sleep(retry_delay * 2 ** attempt)


def llm(pos_tags, neg_tags, input_txt):
    """Ask the model which rounds of one transcript chunk match ``pos_tags``.

    The chunk and tags are rendered into ``USER_FORMAT`` and sent after
    ``SYSTEM_PROMPT``; the prompt and the reply are then printed for
    debugging. ``neg_tags`` is accepted for interface compatibility but is
    not included in the prompt.

    Returns:
        The raw reply text from ``chatgpt``.
    """
    prompt = USER_FORMAT.format(input_txt=input_txt, pos_tags=pos_tags)
    reply = chatgpt([
        dict(role='system', content=SYSTEM_PROMPT),
        dict(role='user', content=prompt),
    ])
    for heading, body in (('USER', prompt), ('RESPONSE', reply)):
        print(f'{heading}:\n\n{body}')
    return reply


def postprocess(input_txt, output_txt_list):
    """Merge the model replies into one list of selected rounds.

    Replies are expected to list items as ``"<header>: tag, reason"`` where
    ``<header>`` looks like ``说话人1 00:05``. Each reply is split on those
    headers on its own, so any text a reply prints before its first header
    (e.g. an introductory sentence) is dropped instead of being glued onto
    the previous reply's last item. Items whose header does not occur in
    ``input_txt`` (hallucinated rounds) are discarded.

    Args:
        input_txt: The original transcript, used to validate headers.
        output_txt_list: Raw reply strings, one per prompt chunk.

    Returns:
        The kept items, each stripped, joined by newlines ('' if none).
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    speakers = set(speaker_pattern.findall(input_txt))

    results = []
    for reply in output_txt_list:
        # Capturing split -> [preamble, header1, item1, header2, item2, ...];
        # pairing by position means item text that starts with "说话人" is
        # kept rather than mistaken for a header. A reply with no header
        # yields no pairs.
        parts = speaker_pattern.split(reply)
        for header, item in zip(parts[1::2], parts[2::2]):
            if header in speakers:
                results.append((header + item).strip())
    return '\n'.join(results)


def filter(pos_tags, neg_tags, input_txt):
    """Return the rounds of ``input_txt`` the model tags with ``pos_tags``.

    The transcript is chunked by ``preprocess``; each chunk is sent through
    ``llm`` in order, and ``postprocess`` merges the replies into one
    newline-separated result.

    NOTE: this shadows the built-in ``filter``; the name is kept because the
    Gradio submit button is wired to it.
    """
    replies = [llm(pos_tags, neg_tags, chunk)
               for chunk in preprocess(input_txt)]
    return postprocess(input_txt, replies)


if __name__ == '__main__':
    # No options yet; parsing still provides --help and rejects stray
    # command-line arguments.
    parser = argparse.ArgumentParser()
    parser.parse_args()

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: tag inputs, transcript input and the buttons.
            with gr.Column(scale=0.3):
                with gr.Row():
                    pos_txt = gr.Textbox(
                        lines=2,
                        label='Positive Tags',
                        elem_id='pos_textbox',
                        placeholder='Enter positive tags split by semicolon')
                with gr.Row():
                    # Hidden from the UI; its value is still passed to
                    # filter() by the submit handler below.
                    neg_txt = gr.Textbox(
                        lines=2,
                        visible=False,
                        label='Negative Tags',
                        elem_id='neg_textbox',
                        placeholder='Enter negative tags split by semicolon')
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=5,
                        label='Input',
                        elem_id='input_textbox',
                        placeholder='Enter text and press submit')
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
            # Right column: the filtered result.
            with gr.Column(scale=0.7):
                output_txt = gr.Textbox(
                    label='Output',
                    elem_id='output_textbox')
                output_txt = output_txt.style(height=690)
            submit.click(
                filter,
                [pos_txt, neg_txt, input_txt],
                [output_txt])
            # click(fn, inputs, outputs, ...): the three cleared components
            # must be passed together as the single `outputs` list, matching
            # the three values the lambda returns.
            clear.click(
                lambda: ['', '', ''],
                None,
                [pos_txt, neg_txt, input_txt])

    # Queue and launch once the Blocks context has closed and the layout is
    # complete.
    demo.queue(concurrency_count=6)
    demo.launch()