moraxgiga committed
Commit 043ca42 · verified · 1 Parent(s): 9a25f52

Update app.py

Files changed (1):
  1. app.py +54 -34
app.py CHANGED
@@ -1,47 +1,67 @@
import gradio as gr
- import torch, os
from transformers import AutoModelForCausalLM, AutoTokenizer
- from transformers import StoppingCriteria, TextIteratorStreamer
from threading import Thread

- torch.set_num_threads(2)
HF_TOKEN = os.environ.get("HF_TOKEN")

- # Loading the tokenizer and model from Hugging Face's model hub.
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", use_auth_token=HF_TOKEN)
- model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", use_auth_token=HF_TOKEN)

def count_tokens(text):
    return len(tokenizer.tokenize(text))

- # Function to generate model predictions.
def predict(message, history):
-
    formatted_prompt = f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    model_inputs = tokenizer(formatted_prompt, return_tensors="pt")
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)
-
-     generate_kwargs = dict(
-         model_inputs,
-         streamer=streamer,
-         max_new_tokens=2048 - count_tokens(formatted_prompt),
-         top_p=0.2,
-         top_k=20,
-         temperature=0.1,
-         repetition_penalty=2.0,
-         length_penalty=-0.5,
-         num_beams=1
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()  # Starting the generation in a separate thread.
-     partial_message = ""
-     for new_token in streamer:
-         partial_message += new_token
-         yield partial_message
-
- # Setting up the Gradio chat interface.
- gr.ChatInterface(predict,
-                  title="Gemma 2b Instruct Chat",
-                  description=None
-                  ).launch()  # Launching the web interface.
 
import gradio as gr
+ import torch, os
from transformers import AutoModelForCausalLM, AutoTokenizer
from threading import Thread

+ # Set the number of threads for PyTorch
+ torch.set_num_threads(3)
+
+ # Your Hugging Face token and model identifiers
HF_TOKEN = os.environ.get("HF_TOKEN")
+ MODEL_NAME = "google/gemma-2b-it"
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN)
+
+ # Load the model and switch it to evaluation mode
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN).eval()

+ # Apply dynamic quantization (int8 Linear weights; see the sketch after this diff)
+ quantized_model = torch.quantization.quantize_dynamic(
+     model,
+     {torch.nn.Linear},  # Specify the layer types to quantize
+     dtype=torch.qint8   # Target datatype for quantized weights
+ )

def count_tokens(text):
+     """Count tokens in the input text."""
    return len(tokenizer.tokenize(text))

def predict(message, history):
+     """Generate a reply with the quantized model (history is currently unused)."""
    formatted_prompt = f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    model_inputs = tokenizer(formatted_prompt, return_tensors="pt")
+
+     # Generation settings; the quantized model is used below
+     generate_kwargs = {
+         "input_ids": model_inputs["input_ids"],
+         "attention_mask": model_inputs["attention_mask"],
+         "max_new_tokens": 2048 - count_tokens(formatted_prompt),
+         "do_sample": True,  # Required for top_p/top_k/temperature to take effect
+         "top_p": 0.2,
+         "top_k": 20,
+         "temperature": 0.1,
+         "repetition_penalty": 2.0,
+         "length_penalty": -0.5,
+         "num_beams": 1,
+         "return_dict_in_generate": True,
+         "output_scores": True
+     }
+
+     with torch.no_grad():  # No gradients needed at inference time
+         output = quantized_model.generate(**generate_kwargs)
+
+     # Decode only the newly generated tokens; output.sequences also contains the prompt
+     new_tokens = output.sequences[0][model_inputs["input_ids"].shape[-1]:]
+     generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+     return generated_text
+
+ # Setting up the Gradio interface
+ interface = gr.Interface(fn=predict,
+                          inputs=[gr.Textbox(label="Your message"), gr.Textbox(label="History", value="")],
+                          outputs="text",
+                          title="Quantized Gemma 2B Chat",
+                          description="This is a Gradio interface for interacting with a quantized version of the Gemma 2B model.")
+
+ # Launch the interface
+ interface.launch()
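
Note on the quantization step above: torch.quantization.quantize_dynamic swaps every nn.Linear for an int8 counterpart whose weights are stored quantized and dequantized on the fly at inference, which targets CPU serving like this Space. A minimal self-contained sketch of the mechanism, using a small stand-in module rather than Gemma (the toy module and its layer sizes are illustrative assumptions, not anything from the commit):

import torch
import torch.nn as nn

# Toy stand-in for the transformer: any nn.Module containing nn.Linear
# layers is handled the same way by dynamic quantization.
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

quantized = torch.quantization.quantize_dynamic(
    toy,
    {nn.Linear},        # layer types to replace
    dtype=torch.qint8,  # int8 storage for the weights
)

# The Linear layers are now torch.nn.quantized.dynamic.Linear modules.
dynamic_linears = [m for m in quantized.modules()
                   if isinstance(m, torch.nn.quantized.dynamic.Linear)]
print(f"{len(dynamic_linears)} Linear layers quantized")  # -> 2

# Inputs and outputs stay float32; only the weight storage and the
# matmul kernel change.
out = quantized(torch.randn(1, 16))
print(out.shape)  # torch.Size([1, 4])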
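
A side note on the prompt formatting in predict: the hand-built <start_of_turn> string matches Gemma's chat format, and on transformers versions that ship chat templates the tokenizer can produce it directly, which avoids drift if the format ever changes. A hedged sketch, assuming the gemma-2b-it tokenizer loaded above and a transformers version with apply_chat_template:

# Equivalent prompt via the tokenizer's built-in chat template.
# Note: with tokenize=False Gemma's template also prepends the BOS token,
# so pass add_special_tokens=False when re-tokenizing the string.
messages = [{"role": "user", "content": message}]
formatted_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False)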