File size: 2,511 Bytes
3ef14f8
 
caeef30
3ef14f8
 
 
a233482
d498112
3ef14f8
 
 
 
 
 
a5971a0
67f4fec
b4f246b
67f4fec
a233482
 
1a8e94e
 
 
 
 
 
 
 
 
 
b4f246b
a233482
1a8e94e
a233482
67f4fec
a5971a0
2a943a9
 
 
 
 
 
 
 
 
 
 
 
 
3ef14f8
 
 
cd30edd
a233482
2a943a9
a233482
 
 
 
 
 
 
 
 
30aaf10
 
d506c2a
30aaf10
86a2758
30aaf10
 
a233482
34bfd19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from transformers import AutoTokenizer, AutoModelForCausalLM

# Identifier handed to both from_pretrained() calls below (by its form a
# Hugging Face Hub repo id — TODO confirm it is not a local path).
model_name = 'armandnlp/gpt2-TOD_finetuned_SGD'
# Tokenizer and causal language model are loaded once, at import time, and
# used as module-level globals by generate_response().
tokenizer_TOD = AutoTokenizer.from_pretrained(model_name)
model_TOD = AutoModelForCausalLM.from_pretrained(model_name)


def generate_response(prompt):
    """Run the dialogue model on *prompt* and return the decoded text.

    The prompt is tokenized into PyTorch tensors, passed to
    ``model_TOD.generate`` with sampling disabled, and the first generated
    sequence is decoded back to a string. Decoding uses the tokenizer's
    default settings, so whatever marker tokens the sequence contains are
    kept in the returned text.

    Args:
        prompt: Dialogue context string, e.g. text wrapped in
            ``<|context|> ... <|endofcontext|>``.

    Returns:
        The decoded first sequence produced by ``generate``.
    """
    encoded = tokenizer_TOD(prompt, return_tensors="pt")
    generated = model_TOD.generate(
        encoded.input_ids,
        do_sample=False,
        # Upper bound on the total sequence length handed to generate().
        max_length=1024,
        # Token id that ends generation; presumably the end-of-response
        # special token of this checkpoint — TODO confirm against the vocab.
        eos_token_id=50262,
    )
    decoded = tokenizer_TOD.batch_decode(generated)
    return decoded[0]

# Example prompt for generate_response:
#   <|context|> <|user|> I want to go to the restaurant.<|endofcontext|>


def _system_utterance(response):
    """Return the text between '<|response|>' and '<|endofresponse|>'."""
    _, found, tail = response.partition('<|response|>')
    # Fall back to the whole string when no response marker is present
    # instead of failing with an IndexError.
    text = tail if found else response
    return text.partition('<|endofresponse|>')[0].strip()


def _build_context(message, history):
    """Serialise all past turns plus the new user *message* into a prompt."""
    parts = ['<|context|>']
    for user_text, raw_response in history:
        parts.append('<|user|> ' + user_text)
        parts.append('<|system|> ' + _system_utterance(raw_response))
    parts.append('<|user|> ' + message)
    parts.append('<|endofcontext|> ')
    return ' '.join(parts)


def chat(message, history):
    """Answer *message* with the dialogue model, given the chat history.

    The history only stores ``(user message, model reply)`` pairs, so the
    model prompt is rebuilt from every stored turn on each call. (Previously
    the code tried to recover the prompt from ``history[-1][0]``, which is
    the plain user message and contains no ``<|endofcontext|>`` marker, so
    the second turn always raised ``ValueError``.)

    Args:
        message: The new user utterance.
        history: List of ``(message, response)`` pairs from earlier turns,
            or a falsy value (``None`` / ``[]``) on the first turn.

    Returns:
        A ``(history, history)`` pair: the same updated list twice, once for
        the chatbot display and once for the session state.
    """
    history = history or []
    context = _build_context(message, history)

    output = generate_response(context)
    # Keep only what the model produced after the prompt. If the marker is
    # missing from the decoded output, keep the whole text rather than crash.
    _, marker, response = output.partition('<|endofcontext|>')
    if not marker:
        response = output

    history.append((message, response))

    return history, history

import random
def chat_test(message, history):
    """Return a canned reply for *message*; a model-free stand-in handler.

    The reply is chosen from the message prefix: "How many" gives a random
    number from 1 to 10, "How" a random mood word, "Where" a random place
    word, and anything else "I don't know".

    Args:
        message: The user utterance.
        history: List of ``(message, response)`` pairs from earlier turns,
            or a falsy value (``None`` / ``[]``) on the first turn.

    Returns:
        A ``(history, history)`` pair: the same updated list twice.
    """
    history = history or []
    # "How many" must be tested before the shorter "How" prefix.
    if message.startswith("How many"):
        # Stringified so every stored reply is a str, like the other branches.
        response = str(random.randint(1, 10))
    elif message.startswith("How"):
        response = random.choice(["Great", "Good", "Okay", "Bad"])
    elif message.startswith("Where"):
        response = random.choice(["Here", "There", "Somewhere"])
    else:
        response = "I don't know"
    history.append((message, response))
    return history, history

import gradio as gr

# Chat display component. color_map supplies two bubble colours
# (presumably user first, bot second — TODO confirm for the installed gradio).
chatbot = gr.Chatbot(color_map=("green", "gray"))

# Inputs are the typed message plus the session state; outputs are the chat
# display plus the updated state, matching the (history, history) pair the
# handler returns.
# NOTE(review): the handler wired in here is the canned chat_test, not the
# model-backed chat() — confirm this is intentional.
iface = gr.Interface(chat_test,
                    ["text", "state"],
                    [chatbot, "state"],
                    allow_screenshot=False,
                    allow_flagging="never",
)



# The string literal below is a disabled alternative UI (never executed): a
# plain text-in/text-out interface that calls generate_response directly.
"""
iface = gr.Interface(fn=generate_response,
                     inputs="text",
                     outputs="text",
                     title="gpt2-TOD",
                     examples=[["<|context|> <|user|> I'm super hungry ! I want to go to the restaurant.<|endofcontext|>"]],
                     description="Passing in a task-oriented dialogue context generates a belief state, actions to take and a response based on those actions",
                     )
"""
# Launch the interface built above with debug mode enabled.
iface.launch(debug=True)