JERNGOC committed · Commit 6debb22 · verified · 1 Parent(s): 7283adc

Create app.py

Files changed (1)
app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 256
MAX_INPUT_TOKEN_LENGTH = 512

DESCRIPTION = """\
# OpenELM-3B-Instruct

This Space demonstrates [OpenELM-3B-Instruct](https://huggingface.co/apple/OpenELM-3B-Instruct) by Apple. Please check the original model card for details.
You can see the other models of the OpenELM family [here](https://huggingface.co/apple/OpenELM).
The following Colab notebooks are available:
* [OpenELM-3B-Instruct (GPU)](https://gist.github.com/Norod/4f11bb36bea5c548d18f10f9d7ec09b0)
* [OpenELM-270M (CPU)](https://gist.github.com/Norod/5a311a8e0a774b5c35919913545b7af4)

You might also be interested in checking out Apple's [CoreNet GitHub page](https://github.com/apple/corenet?tab=readme-ov-file).

If you duplicate this Space, make sure you have access to [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf),
because this model uses it as its tokenizer.

# Note: Use this model only for completing sentences and instruction following.
## While the user interface is a chatbot for convenience, this is an instruction-tuned model that is not fine-tuned for chat. As such, the model is not given the chat history and completes your text based on the last prompt only.
"""

LICENSE = """
<p/>

---
As a derivative work of [OpenELM-3B-Instruct](https://huggingface.co/apple/OpenELM-3B-Instruct) by Apple,
this demo is governed by the original [license](https://huggingface.co/apple/OpenELM-3B-Instruct/blob/main/LICENSE).
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Global variables
model = None
tokenizer = None

if torch.cuda.is_available():
    model_id = "apple/OpenELM-3B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
    # OpenELM ships no tokenizer of its own; it reuses the Llama-2 tokenizer.
    tokenizer_id = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

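# The Llama-2 tokenizer above is gated on the Hub. As a minimal sketch for anyone
# duplicating this Space (not part of the original file; `HF_TOKEN` is an assumed
# secret name), an access token could be passed explicitly:
#
#   hf_token = os.environ.get("HF_TOKEN")
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=hf_token)
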
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.4,
) -> Iterator[str]:
    global model, tokenizer  # Access global variables

    # Only the latest message is tokenized; the chat history is intentionally
    # ignored (see the note in DESCRIPTION above).
    input_ids = tokenizer([message], return_tensors="pt").input_ids
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=5,
    )
    # Run generation on a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.4,
        ),
    ],
    stop_btn=None,
    examples=[
        ["A recipe for a chocolate cake:"],
        ["Can you briefly explain what the Python programming language is?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["Question: What is the capital of France?\nAnswer:"],
        ["Question: I am very tired, what should I do?\nAnswer:"],
    ],
)

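# Note (not in the original file): gr.ChatInterface passes the values of the
# additional_inputs components to fn in order, after (message, chat_history),
# so the slider order above must match generate()'s keyword-argument order.
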
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
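
# A minimal local smoke test, assuming a GPU machine with access to the gated
# tokenizer (a sketch, not part of the original file; @spaces.GPU is a no-op
# outside Spaces, so generate() can be called directly):
#
#   if torch.cuda.is_available():
#       for partial in generate("Question: What is the capital of France?\nAnswer:", []):
#           print(partial)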