Spaces:
Sleeping
Sleeping
File size: 3,663 Bytes
9fd6bb0 d1496ae 9fd6bb0 d1496ae 9fd6bb0 d1496ae f694927 9fd6bb0 d1496ae a47318d c7686bd 9fd6bb0 c7686bd d1496ae 9fd6bb0 d1496ae 9fd6bb0 e82849f 9fd6bb0 e82849f 9fd6bb0 d1496ae 9fd6bb0 d1496ae 9fd6bb0 d1496ae 9fd6bb0 d1496ae 9fd6bb0 d1496ae 9fd6bb0 d1496ae 9fd6bb0 d1496ae d006ef5 9fd6bb0 4a023e5 9fd6bb0 d1496ae 9fd6bb0 d1496ae 9fd6bb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Markdown rendered at the top of the page.
DESCRIPTION = """\
Shakti is a 500 million parameter language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service
For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
"""

# Generation-length limits: the "Max new tokens" slider's ceiling and its
# initial value.
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 512

# Prompts longer than this many tokens are left-trimmed before generation;
# can be overridden through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "2048"))

# First CUDA GPU when one is available, otherwise the CPU.
# NOTE(review): the model is placed with device_map="auto", and this name is
# not referenced elsewhere in the visible code.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face Hub repository holding both the tokenizer and the model.
model_id = "SandLogicTechnologies/Shakti-500M-v2"
# Options shared by both Hub downloads: the access token comes from the
# SHAKTI secret (presumably needed because the repo is gated — confirm), and
# trust_remote_code allows modeling code shipped inside the repository.
_hub_kwargs = {"token": os.getenv("SHAKTI"), "trust_remote_code": True}

tokenizer = AutoTokenizer.from_pretrained(model_id, **_hub_kwargs)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let the loader decide weight placement across devices
    torch_dtype=torch.bfloat16,
    **_hub_kwargs,
)
# Inference only: put the model in evaluation mode.
model.eval()
def _prompt_messages() -> list[dict]:
    """Return the preamble message configured in the ``PROMPT`` env var, if any.

    ``PROMPT`` is decoded with ``json.loads`` and used as a single chat
    message (presumably a ``{"role": ..., "content": ...}`` dict — verify
    against the Space secret). When the variable is unset or empty, e.g. in a
    duplicated Space that lacks the secret, an empty list is returned instead
    of failing with ``TypeError`` from ``json.loads(None)``.
    """
    raw_prompt = os.getenv("PROMPT")
    if not raw_prompt:
        return []
    return [json.loads(raw_prompt)]


@spaces.GPU(duration=180)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a sampled model reply to the latest user message.

    The prompt consists of the optional ``PROMPT`` preamble followed by
    ``message``. Earlier turns in ``chat_history`` are intentionally not sent
    to the model, so each reply depends only on the current message.

    Args:
        message: The latest user message.
        chat_history: Previous turns supplied by ``gr.ChatInterface``; unused.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The reply text accumulated so far, growing with each streamed chunk.
    """
    conversation = _prompt_messages()
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens. NOTE(review): this cuts from the
        # front, so the preamble message is the first thing dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt: do not echo the input tokens back; timeout: iterating the
    # streamer raises if no new text arrives within 50 seconds.
    streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until it finishes, so it runs in a worker thread
    # while this generator drains the streamer chunk by chunk.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has ended; reap the worker.
    worker.join()
# Extra controls shown under the chat box. ChatInterface passes their values
# to `generate` after (message, chat_history), in list order, i.e. as
# max_new_tokens and temperature. top_p, top_k and repetition_penalty are not
# exposed and keep generate's defaults.
_max_new_tokens_slider = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=MAX_MAX_NEW_TOKENS,
    step=1,
    value=DEFAULT_MAX_NEW_TOKENS,
)
_temperature_slider = gr.Slider(
    label="Temperature",
    minimum=0.1,
    maximum=4.0,
    step=0.1,
    value=0.6,
)

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[_max_new_tokens_slider, _temperature_slider],
    examples=[["Tell me a story"], ["write a song"]],
    cache_examples=False,  # examples run live instead of being precomputed
    stop_btn=None,
)
# Page layout: description header, a duplicate-Space button, then the chat UI.
with gr.Blocks(fill_height=True, css="style.css") as demo:
    gr.Markdown(value=DESCRIPTION)
    gr.DuplicateButton(elem_id="duplicate-button", value="Duplicate Space for private use")
    chat_interface.render()
if __name__ == "__main__":
    # Enable request queuing (at most 20 pending events), then start serving.
    demo.queue(max_size=20)
    demo.launch()
|