diabolic6045 committed
Commit 14d1f9b · verified · 1 Parent(s): 03464d3

Update app.py

Files changed (1): app.py (+120 −49)
app.py CHANGED
@@ -1,63 +1,134 @@
+import os
+from threading import Thread
+from typing import Iterator
+
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
 import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+DESCRIPTION = """\
+# Llama 3.2 1B Instruct Finetuned on 10% of Open Hermes Dataset
+This is a demo of [`diabolic6045/open-llama-Instruct`](https://huggingface.co/diabolic6045/open-llama-Instruct).
+"""
 
-# Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("diabolic6045/open-llama-Instruct")
-model = AutoModelForCausalLM.from_pretrained("diabolic6045/open-llama-Instruct")
+MAX_MAX_NEW_TOKENS = 1024
+DEFAULT_MAX_NEW_TOKENS = 512
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+model_id = "diabolic6045/open-llama-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
 model.eval()
-if torch.cuda.is_available():
-    model.to('cuda')
-
-@spaces.GPU()
-def respond(
-    message,
-    history,
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    # Build the conversation history
-    conversation = f"System: {system_message}\n"
-    for user_msg, bot_msg in history:
-        conversation += f"User: {user_msg}\nAssistant: {bot_msg}\n"
-    conversation += f"User: {message}\nAssistant:"
-
-    # Tokenize the input
-    inputs = tokenizer(conversation, return_tensors='pt', truncation=True, max_length=1024)
-    if torch.cuda.is_available():
-        inputs = {k: v.to('cuda') for k, v in inputs.items()}
-
-    # Generate the response
-    output = model.generate(
-        **inputs,
-        max_new_tokens=max_tokens,
+
+
+@spaces.GPU(duration=90)
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = []
+    for user, assistant in chat_history:
+        conversation.extend(
+            [
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant},
+            ]
+        )
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
-        temperature=temperature,
         top_p=top_p,
-        pad_token_id=tokenizer.eos_token_id
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
     )
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
 
-    # Extract the assistant's reply
-    response = response[len(conversation):].strip()
-    return response
 
-# Create the Gradio interface with the Ocean theme
-demo = gr.ChatInterface(
-    fn=respond,
+chat_interface = gr.ChatInterface(
+    fn=generate,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System Message"),
-        gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Max New Tokens"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (Nucleus Sampling)"),
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
     ],
-    title="Open Llama Chatbot",
-    description="Chat with an AI assistant powered by the Open Llama Instruct model.",
-    theme=gr.themes.Ocean(),
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+    cache_examples=False,
 )
 
+with gr.Blocks(css="style.css", fill_height=True, theme=gr.themes.Ocean()) as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    chat_interface.render()
+
 if __name__ == "__main__":
-    demo.launch()
+    demo.queue().launch()
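
One behavioral change in this commit is that `generate` builds the prompt with the tokenizer's chat template instead of hand-formatted `System:`/`User:`/`Assistant:` strings. A minimal sketch of that step, using `tokenize=False` (an assumption for illustration so the rendered prompt string is visible; the app itself passes `return_tensors="pt"` and gets token ids directly):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("diabolic6045/open-llama-Instruct")

conversation = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Explain nucleus sampling in one line."},
]

# add_generation_prompt=True appends the assistant turn header so the model
# continues speaking as the assistant; tokenize=False returns the rendered
# prompt string instead of token ids.
prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
print(prompt)
```

Note that when the prompt exceeds `MAX_INPUT_TOKEN_LENGTH`, the new code keeps the last tokens (`input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]`), so the most recent turns survive trimming.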
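
The other change is streamed output: `model.generate` blocks until generation finishes, so the new code runs it on a background thread and iterates a `TextIteratorStreamer` to yield partial text as it is produced. A minimal sketch of the same pattern outside Gradio, with a tiny stand-in model (`sshleifer/tiny-gpt2` is an assumption for illustration, not the model this Space serves):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"  # hypothetical stand-in; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, world", return_tensors="pt")
# skip_prompt=True makes the streamer emit only newly generated text
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs in a background thread while the
# main thread consumes partial text chunks from the streamer as they arrive
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20))
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```

Yielding the accumulated string on every chunk, as the app's `generate` does, keeps `gr.ChatInterface` showing the full partial response at each update rather than isolated fragments.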