import json
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


DESCRIPTION = """\
Shakti is a 500-million-parameter language model optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service.
For more details, please see the [paper](https://arxiv.org/pdf/2410.11331v1).
"""

MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 512
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "SandLogicTechnologies/Shakti-500M-v2"
# The Hugging Face access token is read from the SHAKTI environment variable.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"), trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token=os.getenv("SHAKTI"),
    trust_remote_code=True,
)
model.eval()  # inference only: disables dropout and other training-time behavior


@spaces.GPU(duration=180)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
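    """Stream a model response to `message`, sampling with the given decoding
    parameters; @spaces.GPU requests GPU time on Hugging Face Spaces per call.
    """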
    # The system prompt is supplied as a JSON message object via the PROMPT env var.
    conversation = [json.loads(os.getenv("PROMPT"))]
    # Chat history is intentionally not replayed: each turn is answered
    # independently of earlier exchanges.
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
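
    # Generation runs on a background thread; the streamer yields decoded text
    # incrementally so partial output can stream back to the UI.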

    streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Tell me a story"],
        ["Write a song"],
    ],
    cache_examples=False,
)
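
# Note: only max-new-tokens and temperature are exposed as UI sliders; top_p,
# top_k, and repetition_penalty fall back to the defaults in generate().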

with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()