File size: 4,332 Bytes
8824f88
 
 
 
 
 
 
 
 
 
0adfa8b
ba0ce5d
8824f88
 
c4bb11b
8824f88
 
a74ee73
c4bb11b
8824f88
ba0ce5d
 
8824f88
 
 
 
 
 
c4bb11b
 
 
8824f88
 
 
 
 
 
 
c4bb11b
0737a9d
 
34353a1
0737a9d
8824f88
c4bb11b
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba0ce5d
8824f88
17aeee6
 
 
d355157
ba0ce5d
a74ee73
ba0ce5d
0adfa8b
c4bb11b
8824f88
 
 
 
 
 
 
 
 
 
c4bb11b
 
 
 
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1038ed
 
 
 
8824f88
b86214b
63414cd
8824f88
 
 
 
 
 
ba0ce5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from chat_interface_preference import ChatInterface

# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Longest prompt (in tokens) fed to the model; generate() keeps only the last
# MAX_INPUT_TOKEN_LENGTH tokens. Overridable through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))

# The model is only loaded when a CUDA device is present.
# NOTE: on CPU-only hosts `model` and `tokenizer` are never bound, so calling
# generate() there fails with a NameError.
if torch.cuda.is_available():
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

# Inline stylesheet handed to ChatInterface(css=...). The pieces below are
# joined by implicit string concatenation into a single string.
style = (
    "<style>"
    ".user-message,.system-message{display:flex;margin:10px}"
    ".user-message .message-content{background-color:#c2e3f7;color:#000}"
    ".system-message .message-content{background-color:#f5f5f5;color:#000}"
    ".message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}"
    ".container{display:flex;justify-content:space-between}"
    ".column{width:48%}"
    "</style>"
)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.06,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the earlier turns.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs, oldest first.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered per step.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The reply generated so far. Each yield is the full accumulated text,
        not just the newest chunk.
    """
    conversation = []
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest part of the prompt is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until generation is finished, so it runs in a worker
    # thread while this generator consumes decoded text from the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted once generation has ended. Join the worker so
    # it is not left running after this generator completes normally. This is
    # deliberately not in a `finally`: joining on consumer cancellation would
    # block until the whole generation finished.
    worker.join()


# Preference-collection chat UI wrapping `generate`.
chat_interface = ChatInterface(
    fn=generate,
    # Keyword spelled as passed originally. Presumably this is the parameter
    # name chat_interface_preference.ChatInterface expects, so do not "fix"
    # it without checking that class.
    prefence_techniques="dpo",
    min_turns=1,
    max_turns=10,
    # NOTE(review): presumably the Hub repo where collected feedback is stored; confirm.
    repo_id="llm-human-feedback-collector-chat-interface-dpo",
    chatbot=gr.Chatbot(
        height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True
    ),
    css=style,
    cache_examples=False,
    # These controls are listed in the same order as generate()'s parameters
    # after (message, chat_history): max_new_tokens, temperature, top_p,
    # top_k, repetition_penalty.
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.2,
            step=0.05,
            value=0.2,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["""What word doesn't make sense in this row: "car, airplane, lama, bus"?"""],
        ["Write a news article about the usage of Lama's by the CSI"],
        ["What are great things cook when getting started with Asian cooking?"],
        ["Who was Anthony Bourdain?"],
    ],
    title="💪🏽🦾 LLM human-feedback collector ChatInterface (DPO) 🦾💪🏽",
    description="""This is an adaptation of the gr.ChatInterface which allows for human feedback collection for SFT, DPO and KTO.""",
)

# Top-level Blocks app that hosts the preconfigured chat interface.
with gr.Blocks(css="style.css") as demo:
    chat_interface.render()

if __name__ == "__main__":
    # Queue requests (at most 20 pending) before serving the app.
    demo.queue(max_size=20)
    demo.launch()