Zakia committed (verified)
Commit 2fa9a9c · Parent: 977ffc9

Update app.py

Files changed (1): app.py (+11 −8)
app.py CHANGED
@@ -1,18 +1,21 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
-# Model name
-model_name = "deepseek-ai/DeepSeek-R1"
+# Select the best distill model for Hugging Face Spaces
+model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 
 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
-# Load model with quantization
+# Load model with quantization for optimized performance
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
+    quantization_config=quantization_config,
+    device_map="auto",
     trust_remote_code=True
-).to("cuda" if torch.cuda.is_available() else "cpu")
+)
 
 # Define the text generation function
 def generate_response(prompt):
@@ -26,9 +29,9 @@ interface = gr.Interface(
     fn=generate_response,
     inputs=gr.Textbox(label="Enter your prompt"),
     outputs=gr.Textbox(label="AI Response"),
-    title="DeepSeek-R1 Chatbot",
-    description="Enter a prompt and receive a response from DeepSeek-R1."
+    title="DeepSeek-R1 Distilled LLaMA Chatbot",
+    description="Enter a prompt and receive a response from DeepSeek-R1-Distill-Llama-8B."
 )
 
 # Launch the app
-interface.launch()
+interface.launch()
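
Note: loading with BitsAndBytesConfig(load_in_8bit=True) and device_map="auto" relies on the bitsandbytes and accelerate packages being installed alongside transformers, so the Space's requirements file would typically need to list them.

The diff truncates the body of generate_response. Below is a minimal sketch of what that function might look like against the quantized model, assuming a plain tokenize → generate → decode flow; the max_new_tokens value and the torch.no_grad() wrapper are illustrative assumptions, not taken from the commit.

# Hypothetical sketch of the truncated generate_response body (not from the commit):
# tokenize the prompt, generate with the 8-bit model, and decode the completion.
def generate_response(prompt):
    # Move inputs to wherever accelerate placed the model weights
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=256)  # assumed generation length
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

With device_map="auto", accelerate decides where the weights live, so the explicit .to("cuda" if torch.cuda.is_available() else "cpu") call from the previous version is no longer needed; inputs are simply moved to model.device before generation.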