File size: 5,228 Bytes
ff96a82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from threading import Thread

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    TextStreamer,
)

model_name_or_path = "sarvamai/OpenHathi-7B-Hi-v0.1-Base"
peft_model_id = "shuvom/OpenHathi-7B-FT-v0.1_SI"

# Load the base model quantized to 4 bits. Passing `load_in_4bit` directly to
# `from_pretrained` is deprecated in transformers; a BitsAndBytesConfig carrying
# the same flag is the supported equivalent.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)

# The tokenizer is loaded from the fine-tuned adapter repo rather than the base
# model; presumably it carries the chat template / added tokens used during
# fine-tuning (verify against the adapter repo).
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Resize the base model's embedding matrix to the adapter tokenizer's vocabulary
# *before* attaching the adapter, presumably so the shapes match what the
# adapter was trained with. TODO: make embedding resizing configurable.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

# Wrap the resized base model with the fine-tuned PEFT adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)

class ChatCompletion:
  """Chat-style text generation around a causal LM and its tokenizer.

  Prompts are rendered with ``tokenizer.apply_chat_template`` and passed to
  ``model.generate`` with nucleus sampling (``top_p=0.95``, ``do_sample=True``).
  """

  def __init__(self, model, tokenizer, system_prompt=None):
    """Store the model and tokenizer and switch the model to eval mode.

    Args:
      model: model exposing ``generate``, ``eval`` and ``device``.
      tokenizer: tokenizer providing ``apply_chat_template``.
      system_prompt: optional default system message, used when a call does
        not supply its own.
    """
    self.model = model
    self.tokenizer = tokenizer
    # Kept for backward compatibility only; get_chat_completion creates a fresh
    # streamer per call instead of sharing this one.
    self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
    # Prints generated text to stdout as it is produced (used by get_completion).
    self.print_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
    # Inference mode (e.g. disables dropout).
    self.model.eval()
    self.system_prompt = system_prompt

  def _build_messages(self, prompt, system_prompt=None, message_history=None):
    """Return the chat message list for a single-prompt completion.

    If ``message_history`` (a list of ``{"role", "content"}`` dicts) is given,
    it is copied as-is and no system message is added. Otherwise a system
    message is prepended from ``system_prompt`` or, failing that, the instance
    default, when either is set. ``prompt`` is appended as the final user turn.
    The caller's ``message_history`` list is not mutated.
    """
    messages = []
    if message_history is not None:
      messages.extend(message_history)
    else:
      system_prompt = system_prompt or self.system_prompt
      if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages

  def _chat_messages(self, message, history):
    """Convert a Gradio chat ``history`` plus the new ``message`` into chat messages.

    ``history`` may hold ``(user, assistant)`` pairs (Gradio "tuples" format)
    or ``{"role", "content"}`` dicts (Gradio "messages" format). Bot replies are
    labelled ``"assistant"`` so the template sees a proper user/assistant
    alternation. The instance ``system_prompt``, if set, comes first.
    """
    messages = []
    if self.system_prompt:
      messages.append({"role": "system", "content": self.system_prompt})
    for turn in history or ():
      if isinstance(turn, dict):
        messages.append({"role": turn["role"], "content": turn["content"]})
        continue
      user_message, assistant_message = turn
      messages.append({"role": "user", "content": user_message})
      # A pending turn may have no reply yet; don't feed None to the template.
      if assistant_message is not None:
        messages.append({"role": "assistant", "content": assistant_message})
    messages.append({"role": "user", "content": message})
    return messages

  def _encode(self, messages):
    """Render ``messages`` with the chat template and tokenize onto the model's device."""
    chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The chat template is expected to emit its own special tokens (e.g. BOS);
    # letting the tokenizer add them again would duplicate them.
    inputs = self.tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False)
    # The tokenizer returns CPU tensors; generate() needs them on the model's device.
    return inputs.to(self.model.device)

  def get_completion(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
    """Generate a reply to ``prompt``, printing it to stdout as it is produced.

    Args:
      prompt: the user message.
      system_prompt: overrides the instance default; ignored when
        ``message_history`` is given.
      message_history: optional list of ``{"role", "content"}`` dicts placed
        before ``prompt``.
      max_new_tokens: generation length limit.
      temperature: sampling temperature; values below 0.01 are raised to 0.01
        because sampling is always enabled.

    Returns:
      The token-id tensor returned by ``model.generate`` (prompt tokens followed
      by generated tokens), not decoded text.
    """
    # do_sample=True is always used, so keep a temperature of 0 away from the sampler.
    temperature = max(temperature, 1e-2)
    inputs = self._encode(self._build_messages(prompt, system_prompt, message_history))
    # Synchronous call; the TextStreamer prints text as tokens arrive.
    return self.model.generate(
        **inputs,
        streamer=self.print_streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=0.95,
        do_sample=True,
        eos_token_id=self.tokenizer.eos_token_id,
        repetition_penalty=1.2,
    )

  def get_chat_completion(self, message, history):
    """Stream a reply for a Gradio ``ChatInterface``.

    Args:
      message: the new user message.
      history: previous turns, as ``(user, assistant)`` pairs or role dicts.

    Yields:
      The accumulated reply text after each decoded chunk.
    """
    inputs = self._encode(self._chat_messages(message, history))
    # A fresh streamer per call: a shared one would hand leftover tokens (and a
    # stale end-of-stream signal) from an abandoned generation to the next request.
    streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=2048,
        temperature=0.2,
        top_p=0.95,
        eos_token_id=self.tokenizer.eos_token_id,
        do_sample=True,
        repetition_penalty=1.2,
    )
    # generate() blocks, so run it in a background thread and consume the
    # streamer here to yield partial text as it arrives.
    thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
      # The streamer does not skip special tokens; hide EOS from the UI.
      generated_text += new_text.replace(self.tokenizer.eos_token, "")
      yield generated_text
    thread.join()
    return generated_text

  def get_completion_without_streaming(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
    """Generate a reply to ``prompt`` and return it as decoded text.

    Takes the same arguments as ``get_completion``; uses a milder
    ``repetition_penalty`` (1.1) and the model's default EOS setting.

    Returns:
      ``outputs[0]`` decoded with special tokens removed. NOTE(review): for a
      decoder-only model this includes the rendered prompt as well as the
      reply; slice off ``inputs["input_ids"].shape[1]`` tokens if only the
      reply is wanted.
    """
    # do_sample=True is always used, so keep a temperature of 0 away from the sampler.
    temperature = max(temperature, 1e-2)
    inputs = self._encode(self._build_messages(prompt, system_prompt, message_history))
    outputs = self.model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.1,
    )
    return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# Default system prompt for the demo chat.
SYSTEM_PROMPT = "You are a native Hindi speaker who can converse at expert level in both Hindi and colloquial Hinglish."

text_generator = ChatCompletion(model, tokenizer, system_prompt=SYSTEM_PROMPT)

# Chat UI backed by the streaming generator. Enable request queuing and start
# the server; launch(debug=True) blocks the main thread.
demo = gr.ChatInterface(text_generator.get_chat_completion)
demo.queue().launch(debug=True)