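"""Gradio demo for AstroLLaMA (universeTBD/astrollama).

Loads the model in 4-bit precision and streams text generated from a user
prompt through a simple web interface.
"""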
from threading import Thread
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TextIteratorStreamer
)

MODEL_ID = "universeTBD/astrollama"
WINDOW_SIZE = 4096  # maximum context length; used to cap max_new_tokens below
DEVICE = "cuda"

config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID
)

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16
)
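# NOTE: loading with load_in_4bit=True above requires the bitsandbytes package
# to be installed alongside transformers.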


# NOTE: the parameter order after `prompt` matches the order of the Gradio
# input components defined below, which are passed to this function positionally.
def generate_text(prompt: str,
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50,
                  max_new_tokens: int = 512) -> str:

    # Encode the prompt
    inputs = tokenizer([prompt],
                       return_tensors="pt",
                       add_special_tokens=False,
                       return_token_type_ids=False)
    inputs = inputs.to(DEVICE)

    # Prepare arguments for generation
    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
    if temperature >= 1.0:
        temperature = 0.99
    elif temperature <= 0.0:
        temperature = 0.01
    if top_p > 1.0 or top_p <= 0.0:
        top_p = 1.0
    if top_k <= 0:
        top_k = 100
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    # generate() receives input_ids/attention_mask via **inputs and writes
    # decoded tokens to the streamer as they are produced
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )

    # Generate text
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Collect the streamed chunks and return the full completion
    outputs = []
    for text in streamer:
        outputs.append(text)
    thread.join()
    return "".join(outputs)


demo = gr.Interface(
    fn=generate_text,
    inputs=[
        # Prompt
        gr.Textbox(
            label="Prompt",
            container=False,
            show_label=False,
            placeholder="Enter some text...",
            scale=10,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.5,
        ),
        gr.Slider(
            label="Top-p (for sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k (for sampling)",
            minimum=1,
            maximum=1000,
            step=1,
            value=100,
        )
    ],
    outputs=[
        gr.Textbox(
            container=False,
            show_label=False,
            placeholder="Generated output...",
            scale=10,
        )
    ],
)

demo.queue(max_size=20).launch(server_port=7878)
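# Usage (assumed): run this script directly, e.g. `python app.py`, then open
# http://localhost:7878 in a browser to interact with the demo.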