File size: 4,359 Bytes
2aea843
6dcdc90
d36aee0
252de57
 
 
 
d36aee0
252de57
2aea843
98fd48f
 
 
 
 
0634305
252de57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8987d5f
 
 
 
 
 
252de57
8987d5f
252de57
8987d5f
 
252de57
 
 
 
 
69830ca
 
 
 
 
 
 
42ac4a1
252de57
df889e0
3eced92
69830ca
35abf4c
 
252de57
 
 
 
 
 
 
 
 
 
d36aee0
a3015fe
 
 
 
 
1376587
0e9349e
1376587
 
 
064f7a3
1376587
064f7a3
92a6dd7
1376587
 
 
 
a4f154b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import time
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

# Everything below is fetched from the Hugging Face Hub; an access token, if one
# is needed, is taken from the HF_TOKEN environment variable.
_HF_TOKEN = os.environ.get('HF_TOKEN')
_BASE_MODEL_ID = "soketlabs/pragna-1b"

# Tokenizer and base model share one repository; the model weights are pinned to
# a specific commit so the app always loads the same revision.
tokenizer = AutoTokenizer.from_pretrained(_BASE_MODEL_ID, token=_HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    _BASE_MODEL_ID,
    token=_HF_TOKEN,
    revision='3c5b8b1309f7d89710331ba2f164570608af0de7',
)

# Load the separate "-it" adapter repository on top of the base model.
model.load_adapter('soketlabs/pragna-1b-it-v0.1', token=_HF_TOKEN)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)


# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is produced.

    Only the first sequence of the batch is inspected, so this is intended for
    single-prompt generation (as done in ``predict``).

    Args:
        stop_ids: Token ids that end generation. Defaults to ``(2,)``, the id the
            original implementation hard-coded (presumably the tokenizer's EOS
            id -- TODO confirm against ``tokenizer.eos_token_id``).
    """

    def __init__(self, stop_ids=(2,)):
        super().__init__()
        # frozenset: O(1) membership tests, and later mutation of the caller's
        # sequence cannot silently change this criterion's behaviour.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the last token of the first sequence is a stop id.

        Args:
            input_ids: Token ids so far, indexed below as ``input_ids[0, -1]``
                (so at least 2-D; batch-first is assumed).
            scores: Unused; required by the ``StoppingCriteria`` call signature.
            **kwargs: Unused extra arguments forwarded by ``generate``.

        Returns:
            True when generation should stop, False otherwise.
        """
        # An empty batch or an empty sequence has no "last token"; report "keep
        # going" instead of raising IndexError.
        if input_ids.shape[0] == 0 or input_ids.shape[-1] == 0:
            return False
        return int(input_ids[0, -1]) in self.stop_ids


# System prompt prepended to every conversation.
_SYS_PROMPT = 'You are Pragna, an AI built by Soket AI Labs. You should never lie and always tell facts. Help the user as much as you can and be open to say I dont know this if you are not sure of the answer'


def _build_prompt(message, history, eos_token):
    """Render the system prompt, prior turns and ``message`` as one prompt string.

    Layout: ``<|system|>\\n{system}{eos}`` followed by one
    ``<|user|>\\n{user}{eos}<|assistant|>\\n{assistant}`` segment per turn, with
    consecutive segments separated by ``eos_token``. The new ``message`` is the
    final turn with an empty assistant reply, so the prompt ends on an open
    ``<|assistant|>\\n`` header for the model to continue from.

    Args:
        message: The user's new message.
        history: Prior turns as indexable ``[user, assistant]`` pairs; may be
            empty or None. A None assistant reply is treated as "".
        eos_token: The marker that terminates each segment.

    Returns:
        The complete prompt string.
    """
    turns = list(history or []) + [[message, ""]]
    segments = [
        f"<|user|>\n{turn[0]}{eos_token}<|assistant|>\n{turn[1] or ''}"
        for turn in turns
    ]
    return f"<|system|>\n{_SYS_PROMPT}{eos_token}" + eos_token.join(segments)


# Function to generate model predictions.
def predict(message, history):
    """Stream the model's reply to ``message`` given the earlier chat ``history``.

    Builds a single prompt, starts ``model.generate`` on a background thread and
    yields the reply accumulated so far each time the streamer produces text.

    Args:
        message: The user's new message.
        history: Prior turns as ``[user, assistant]`` pairs (Gradio's tuple-style
            chat history -- assumed from the indexing used; confirm against the
            installed Gradio version).

    Yields:
        The decoded reply so far; each value supersedes the previous one.
    """
    # Use one EOS marker for every separator. Fall back to the "</s>" literal the
    # original prompt format hard-coded, because a None eos_token would
    # otherwise be rendered as the text "None".
    eos_token = tokenizer.eos_token or "</s>"
    prompt = _build_prompt(message, history, eos_token)
    print(prompt)  # echo the full prompt to the server log for debugging

    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    # skip_prompt: stream only newly generated text. timeout: stop waiting (raise)
    # if no text arrives within 10 s instead of blocking forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=300,
        do_sample=True,
        top_k=5,
        num_beams=1,
        # NOTE(review): with the KV cache disabled, every decoding step re-processes
        # the whole sequence, which is much slower. It is kept as in the original
        # in case the pinned model revision requires it -- confirm before enabling.
        use_cache=False,
        temperature=0.2,
        repetition_penalty=1.1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # generate() blocks until finished, so it runs on a worker thread while this
    # generator drains the streamer. NOTE(review): the thread is never joined; if
    # the consumer stops early, it runs on until a stop token or max_new_tokens.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        # With skip_special_tokens=True the EOS marker is normally stripped from
        # the streamed text; this check is only a safety net.
        if eos_token in partial_message:
            break
        yield partial_message

def slow_echo(message, history):
    """Yield ever-longer prefixes of ``message``, each prefixed with "You typed: ".

    Sleeps 50 ms before each yield to imitate a streamed reply. ``history`` is
    accepted only to mirror the ``(message, history)`` signature of ``predict``
    and is ignored.
    """
    for end in range(1, len(message) + 1):
        time.sleep(0.05)
        yield "You typed: " + message[:end]

# Example prompts shown under the chat box (English plus Devanagari, Bengali and
# Gujarati script prompts).
_EXAMPLES = [
    'Tell me about India',
    'मुझे भारत के बारे में बताओ?',
    'भारत के प्रधान मंत्री कौन हैं',
    'भारत को आजादी कब मिली',
    'আমাকে ভারত সম্পর্কে বলুন',
    'ભારતની રાજધાની શું છે?',
    'મને ભારત વિશે કહો ',
    'কলকাতার ঐতিহাসিক তাৎপর্য কী। বিস্তারিত বলুন।',
]

_DESCRIPTION = "Disclaimer: An initial checkpoint of the instruction tuned model is made available as a research preview. It is hereby cautioned that the model has the potential to produce hallucinatory and plausible yet inaccurate statements. Users are advised to exercise discretion when utilizing the generated content."

# Chat UI wired to the streaming `predict` generator.
_chatbot = gr.Chatbot(height=300)
_textbox = gr.Textbox(placeholder="Try Pragna SFT", container=False, scale=7)
demo = gr.ChatInterface(
    predict,
    chatbot=_chatbot,
    textbox=_textbox,
    title="pragna-1b-it",
    description=_DESCRIPTION,
    theme="soft",
    examples=_EXAMPLES,
    cache_examples=False,
    # NOTE(review): the retry/undo/clear button kwargs follow the Gradio 4.x
    # ChatInterface API; confirm the installed Gradio version still accepts them.
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)
# Turn on Gradio's request queue for this app.
demo = demo.queue()

if __name__ == "__main__":
    demo.launch()