File size: 3,976 Bytes
5956319
74995d7
5956319
f7d8c6a
 
 
9c1d271
f7d8c6a
ebcc5ea
f7d8c6a
 
 
 
 
 
 
 
a1908d6
f7d8c6a
 
f316cfc
f7d8c6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13cfe75
 
f7d8c6a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

# The model is loaded lazily, inside predict(), to meet Hugging Face's stateless GPU requirements.

# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation when a stop token is produced.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.

    Args:
        stop_ids: Iterable of token IDs that end generation. Defaults to
            ``(50256, 50295)``, the IDs this file has always used.
            NOTE(review): token IDs are tokenizer-specific — confirm these
            match the stop tokens of the model actually being served.
    """

    DEFAULT_STOP_IDS = (50256, 50295)

    def __init__(self, stop_ids=DEFAULT_STOP_IDS):
        super().__init__()
        # A frozenset gives O(1) membership tests and guards against mutation.
        self.stop_ids = frozenset(int(stop_id) for stop_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the most recently generated token is a stop token."""
        # Convert once to a plain int instead of comparing tensors per stop ID.
        last_token = int(input_ids[0][-1])
        return last_token in self.stop_ids


# Prompt construction helper and the function that generates model predictions.
def _build_chatml_prompt(message, history):
    """Return the ChatML prompt for the past turns followed by *message*.

    Args:
        message: The new user message.
        history: Past turns, each indexable as ``[user_text, assistant_text]``.

    Returns:
        The full prompt string, ending with an open assistant turn.
    """
    segments = ["<|im_start|>system\nYou are Dolphin, a helpful AI assistant.<|im_end|>"]
    for turn in history:
        # Every completed assistant turn must be closed with <|im_end|>;
        # otherwise the reply runs straight into the next user turn.
        segments.append(
            "\n<|im_start|>user\n" + turn[0]
            + "<|im_end|>\n<|im_start|>assistant\n" + (turn[1] or "")
            + "<|im_end|>"
        )
    # The last assistant turn is left open so the model writes the reply.
    segments.append(
        "\n<|im_start|>user\n" + message + "<|im_end|>\n<|im_start|>assistant\n"
    )
    return "".join(segments)


@spaces.GPU
def predict(message, history):
    """Stream the model's reply to *message*.

    The tokenizer and model are loaded inside the function, i.e. on every
    call, and generation runs on the CUDA device.

    Args:
        message: The latest user message.
        history: Previous turns as ``[user_text, assistant_text]`` pairs.

    Yields:
        The reply text accumulated so far; each yielded value extends the
        previous one.
    """
    torch.set_default_device("cuda")

    # Loading the tokenizer and model from Hugging Face's model hub.
    model_id = "macadeliccc/laser-dolphin-mixtral-2x7b-dpo"
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        load_in_4bit=True,
        trust_remote_code=True
    )

    # Formatting the input for the model.
    prompt = _build_chatml_prompt(message, history)
    model_inputs = tokenizer([prompt], return_tensors="pt").to('cuda')
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,  # the tokenizer output's entries become generate() keyword arguments
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()])
    )

    # Generation runs in a separate thread so this generator can consume the
    # streamer while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    end_marker = '<|im_end|>'
    partial_message = ""
    already_shown = ""
    for new_text in streamer:
        partial_message += new_text
        marker_pos = partial_message.find(end_marker)
        if marker_pos != -1:
            # Stop at the end-of-turn marker, but first emit any text that
            # arrived in the same chunk ahead of it.
            final_message = partial_message[:marker_pos]
            if final_message != already_shown:
                yield final_message
            break
        already_shown = partial_message
        yield partial_message


# Assemble the pieces of the Gradio chat UI, then build and serve it.

# Markdown/HTML blurb passed to the interface as its description.
chat_description = """
                 <center><img src="https://huggingface.co/macadeliccc/laser-dolphin-mixtral-2x7b-dpo/resolve/main/dolphin_moe.png" width="33%"></center>\n\n
                 Chat with [macadeliccc/laser-dolphin-mixtral-2x7b-dpo](https://huggingface.co/macadeliccc/laser-dolphin-mixtral-2x7b-dpo), the first Mixture of Experts made of lasered 7b models. 
                 This model (24.2B param) scores very well on many evaluations. More information is available on the model card. Output is considered experimental.\n\n
                 ❤️ If you like this work, please follow me on [Hugging Face](https://huggingface.co/macadeliccc) and [LinkedIn](https://www.linkedin.com/in/tim-dolan-python-dev/).
                 """

# Example prompts passed to the interface.
example_prompts = [
    'Can you solve the equation 2x + 3 = 11 for x?',
    'How does Fermats last theorem impact number theory?',
    'What is a vector in the scope of computer science rather than physics?',
    'Use a list comprehension to create a list of squares for numbers from 1 to 10.',
    'Recommend some popular science fiction books.',
    'Can you write a short story about a time-traveling detective?',
]

demo = gr.ChatInterface(
    predict,
    description=chat_description,
    examples=example_prompts,
    theme=gr.themes.Soft(primary_hue="purple"),
)
demo.launch()