Zakia committed on
Commit cda3c49 · verified · 1 Parent(s): 2fa9a9c

Update app.py

Files changed (1)
app.py +12 -6
app.py CHANGED
@@ -2,22 +2,28 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
-# Select the best distill model for Hugging Face Spaces
+# Use a more compatible DeepSeek model
 model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 
 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
-# Load model with quantization for optimized performance
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+# Fix quantization issue by using 4-bit
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,  # Use 4-bit instead of 8-bit
+    bnb_4bit_compute_dtype=torch.float16,  # Use FP16 for better compatibility
+    bnb_4bit_use_double_quant=True,  # Enable double quantization for efficiency
+)
+
+# Load model with optimized quantization
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
-    quantization_config=quantization_config,
     device_map="auto",
+    quantization_config=quantization_config,
     trust_remote_code=True
 )
 
-# Define the text generation function
+# Define text generation function
 def generate_response(prompt):
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     with torch.no_grad():
@@ -29,7 +35,7 @@ interface = gr.Interface(
     fn=generate_response,
     inputs=gr.Textbox(label="Enter your prompt"),
     outputs=gr.Textbox(label="AI Response"),
-    title="DeepSeek-R1 Distilled LLaMA Chatbot",
+    title="DeepSeek-R1 Distill LLaMA Chatbot",
     description="Enter a prompt and receive a response from DeepSeek-R1-Distill-Llama-8B."
 )
 
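The hunks above cover only the changed regions, so the tail of generate_response and the app launch are not shown in the diff. Below is a minimal sketch of how the remainder of app.py plausibly reads, assuming the standard transformers generate/decode calls and Gradio's launch(); the max_new_tokens value and the output_ids name are illustrative and not taken from the commit.

# Hypothetical continuation of app.py; these lines are not part of this commit's diff.
def generate_response(prompt):
    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Run inference without gradient tracking
        output_ids = model.generate(**inputs, max_new_tokens=512)  # assumed token budget
    # Decode the generated token IDs back into text
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

interface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=gr.Textbox(label="AI Response"),
    title="DeepSeek-R1 Distill LLaMA Chatbot",
    description="Enter a prompt and receive a response from DeepSeek-R1-Distill-Llama-8B."
)

interface.launch()

As a design note, the 4-bit config roughly halves the weight memory footprint relative to load_in_8bit, and the float16 compute dtype keeps matrix multiplications in a widely supported precision, which matches the "Fix quantization issue" intent stated in the diff comments.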