Update app.py
app.py (changed)
@@ -16,7 +16,8 @@ app = Flask(__name__)
 CORS(app) # Allow cross-origin requests
 
 # Model configuration
-
+# Use DeepSeek R1 Distill Qwen 7B model
+MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
 MAX_NEW_TOKENS = 256
 DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
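The commit points `MODEL_NAME` at `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B`. Before restarting the app it can be worth checking that the new model id resolves and the tokenizer downloads into the expected cache. A minimal sketch, run outside the Flask app; the `cache_dir` value here is an assumption meant to mirror whatever `load_model()` uses, and `HF_TOKEN` is only needed for gated repos:

```python
# Sanity check that the new MODEL_NAME resolves and the tokenizer downloads.
# Assumes `transformers` is installed; cache_dir below is a placeholder.
import os
from transformers import AutoTokenizer

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
token = os.environ.get("HF_TOKEN")  # optional, only for gated repos

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    cache_dir="./model_cache",   # assumption: match the app's cache_dir
    token=token,
    trust_remote_code=True,
)
print(f"Tokenizer loaded, vocab size: {tokenizer.vocab_size}")
```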
@@ -32,24 +33,44 @@ def load_model():
         return True
 
     try:
-        from transformers import AutoTokenizer, AutoModelForCausalLM
+        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
         print(f"Loading model {MODEL_NAME}...")
         print(f"Using device: {DEVICE}")
         print(f"Cache directory: {cache_dir}")
 
+        # Use 4-bit quantization for memory efficiency if on CUDA
+        if DEVICE == "cuda":
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True
+            )
+        else:
+            quantization_config = None
+
         # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             MODEL_NAME,
-            cache_dir=str(cache_dir)
+            cache_dir=str(cache_dir),
+            trust_remote_code=True
         )
 
-        #
+        # Configure token if HF_TOKEN is set
+        hf_token = os.environ.get("HF_TOKEN")
+        token_kwargs = {"token": hf_token} if hf_token else {}
+
+        # Load model with appropriate settings for the device
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_NAME,
             cache_dir=str(cache_dir),
             device_map="auto" if DEVICE == "cuda" else None,
             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
-
+            quantization_config=quantization_config,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            **token_kwargs
+        )
 
         print("✅ Model loaded successfully!")
         return True
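The core of this hunk is the `BitsAndBytesConfig`: on CUDA the weights are loaded as 4-bit NF4 with double quantization and fp16 compute, while the CPU path keeps plain fp32 and skips quantization. A standalone sketch of the same loading path with the Flask plumbing stripped out; it assumes `bitsandbytes` and `accelerate` are installed and a CUDA GPU is available, and uses the model id configured above:

```python
# Minimal sketch of the 4-bit loading path used in load_model().
# Assumes: pip install transformers accelerate bitsandbytes, plus a CUDA GPU.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
hf_token = os.environ.get("HF_TOKEN")  # only needed for gated repos

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights in 4-bit
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",        # let accelerate place the quantized weights
    low_cpu_mem_usage=True,
    token=hf_token,
)

# get_memory_footprint() reports the in-memory size of the loaded weights
print(f"Model footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```

With NF4 a 7B checkpoint typically occupies roughly 4 GB of VRAM instead of ~14 GB in fp16, which is the point of the `if DEVICE == "cuda"` branch.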
@@ -79,7 +100,15 @@ def stream_generator(prompt):
 
     # Prepare streaming generation
     try:
-
+        # Format prompt for the model
+        if "mistral" in MODEL_NAME.lower():
+            formatted_prompt = f"<s>[INST] {prompt} [/INST]"
+        elif "deepseek" in MODEL_NAME.lower():
+            formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+        else:
+            formatted_prompt = prompt
+
+        inputs = tokenizer(formatted_prompt, return_tensors="pt")
         if DEVICE == "cuda":
             inputs = inputs.to("cuda")
 
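Both generation paths now hand-roll the chat markers: Mistral-style `[INST]` tags or ChatML-style `<|im_start|>` blocks chosen by substring match on `MODEL_NAME`. An alternative is `tokenizer.apply_chat_template`, which builds the prompt from whatever chat template the checkpoint ships, so it stays correct if `MODEL_NAME` changes again. A small sketch, assuming the tokenizer loaded in `load_model()` defines a chat template (recent instruction-tuned checkpoints generally do):

```python
# Sketch: build the prompt from the model's own chat template instead of
# hard-coding [INST] / <|im_start|> markers. Assumes `tokenizer` is the
# tokenizer loaded in load_model() and that it defines a chat template.
def build_prompt(tokenizer, prompt: str) -> str:
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a string, not token ids
        add_generation_prompt=True,  # append the assistant prefix
    )

# Usage inside stream_generator() / chat():
#   formatted_prompt = build_prompt(tokenizer, prompt)
#   inputs = tokenizer(formatted_prompt, return_tensors="pt")
```

This also guards against the hard-coded markers drifting from the template the checkpoint actually expects.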
@@ -163,7 +192,15 @@ def chat():
         return jsonify({"error": "Empty prompt"}), 400
 
     try:
-
+        # Format prompt for the model
+        if "mistral" in MODEL_NAME.lower():
+            formatted_prompt = f"<s>[INST] {prompt} [/INST]"
+        elif "deepseek" in MODEL_NAME.lower():
+            formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+        else:
+            formatted_prompt = prompt
+
+        inputs = tokenizer(formatted_prompt, return_tensors="pt")
         if DEVICE == "cuda":
             inputs = inputs.to("cuda")
 
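The same formatting branch is now duplicated in `stream_generator()` and `chat()`. A small follow-up refactor could hoist it into one helper so the two endpoints cannot drift apart; `format_prompt` below is a hypothetical helper name, with the branch bodies copied from this commit:

```python
# Hypothetical helper consolidating the duplicated prompt formatting.
def format_prompt(prompt: str) -> str:
    name = MODEL_NAME.lower()
    if "mistral" in name:
        return f"<s>[INST] {prompt} [/INST]"
    if "deepseek" in name:
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    return prompt

# Both call sites then reduce to:
#   inputs = tokenizer(format_prompt(prompt), return_tensors="pt")
#   if DEVICE == "cuda":
#       inputs = inputs.to("cuda")
```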