moraxgiga committed on
Commit 9f8e55d · verified · 1 Parent(s): 8f97780

Update app.py

Files changed (1)
  1. app.py +33 -48
app.py CHANGED
@@ -1,64 +1,49 @@
 import gradio as gr
 import torch, os
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import StoppingCriteria, TextIteratorStreamer
 from threading import Thread
 
-# Set the number of threads for PyTorch
 torch.set_num_threads(3)
 
-# Your Hugging Face token and model identifiers
 HF_TOKEN = os.environ.get("HF_TOKEN")
-MODEL_NAME = "google/gemma-2b-it"
 
-# Load the tokenizer
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN)
-
-# Load the model and switch it to evaluation mode
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN).eval()
-
-# Apply dynamic quantization
-quantized_model = torch.quantization.quantize_dynamic(
-    model,
-    {torch.nn.Linear},  # Specify the layer types to quantize
-    dtype=torch.qint8   # Target datatype for quantized weights
-)
+# Loading the tokenizer and model from Hugging Face's model hub.
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", use_auth_token=HF_TOKEN)
+model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", use_auth_token=HF_TOKEN).eval()
 
 def count_tokens(text):
-    """Count tokens in the input text."""
+
     return len(tokenizer.tokenize(text))
 
+# Function to generate model predictions.
 def predict(message, history):
-    """Generate predictions using the quantized model."""
+
     formatted_prompt = f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
     model_inputs = tokenizer(formatted_prompt, return_tensors="pt")
-
-    # Ensure to use the quantized model for prediction
-    generate_kwargs = {
-        "input_ids": model_inputs["input_ids"],
-        "max_length": 2048 - count_tokens(formatted_prompt),
-        "top_p": 0.2,
-        "top_k": 20,
-        "temperature": 0.1,
-        "repetition_penalty": 2.0,
-        "length_penalty": -0.5,
-        "num_beams": 1,
-        "return_dict_in_generate": True,
-        "output_scores": True
-    }
-
-    with torch.no_grad():  # Ensure no gradient is computed to save memory and computation
-        output = quantized_model.generate(**generate_kwargs)
-
-    # Decode and return the generated text
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    return generated_text
-
-# Setting up the Gradio interface
-interface = gr.Interface(fn=predict,
-                         inputs=[gr.inputs.Textbox(label="Your message"), gr.inputs.Textbox(label="History", default="")],
-                         outputs="text",
-                         title="Quantized Gemma 2B Chat",
-                         description="This is a Gradio interface for interacting with a quantized version of the Gemma 2B model.")
-
-# Launch the interface
-interface.launch()
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=2048 - count_tokens(formatted_prompt),
+        top_p=0.2,
+        top_k=20,
+        temperature=0.1,
+        repetition_penalty=2.0,
+        length_penalty=-0.5,
+        num_beams=1
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()  # Starting the generation in a separate thread.
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        yield partial_message
+
+# Setting up the Gradio chat interface.
+gr.ChatInterface(predict,
+                 title="Gemma 2b Instruct Chat",
+                 description=None
+                 ).launch()  # Launching the web interface.
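
For context, the pattern the new app.py adopts is the standard transformers streaming recipe: model.generate() blocks until generation finishes, so it runs on a worker thread while TextIteratorStreamer hands decoded text chunks back to the main thread, and predict() yields a growing partial message that gr.ChatInterface renders incrementally. A minimal self-contained sketch of that pattern (the "gpt2" checkpoint and the prompt are placeholders for illustration, not part of this commit):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model; any causal LM checkpoint works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt")

# skip_prompt=True drops the echoed prompt from the stream;
# skip_special_tokens is forwarded to tokenizer.decode().
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs on a background thread while
# the main thread consumes decoded chunks from the streamer as they arrive.
thread = Thread(target=model.generate,
                kwargs=dict(inputs, streamer=streamer, max_new_tokens=40))
thread.start()

partial = ""
for chunk in streamer:
    partial += chunk
    print(partial)  # a Gradio generator would `yield partial` here instead
thread.join()

The timeout=120. that app.py passes to its streamer bounds how long the iterator waits for the next chunk; if generation stalls past that, iteration raises instead of hanging the request.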