Update app.py
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
 from flask import Flask, jsonify, request
 from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
+from accelerate import Accelerator
 import re
 import traceback
 
@@ -17,6 +18,7 @@ CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "ht
 # Global variables for model and tokenizer
 model = None
 tokenizer = None
+accelerator = Accelerator()
 
 def get_model_and_tokenizer(model_id: str):
     global model, tokenizer
@@ -36,11 +38,14 @@ def get_model_and_tokenizer(model_id: str):
         model = AutoModelForCausalLM.from_pretrained(
             model_id, quantization_config=bnb_config, device_map="auto"
         )
-
+
         model.config.use_cache = False
         model.config.pretraining_tp = 1
         model.config.pad_token_id = tokenizer.eos_token_id  # Fix padding issue
 
+        # Ensure model is placed on the correct device using accelerate
+        model = accelerator.prepare(model)
+
     except Exception as e:
         print("Error loading model:")
         print(traceback.format_exc())  # Logs the full error traceback
@@ -51,7 +56,7 @@ def generate_response(user_input, model_id):
     get_model_and_tokenizer(model_id)
 
     prompt = user_input
-    device =
+    device = accelerator.device  # Automatically uses GPU or CPU based on accelerator setup
 
     generation_config = GenerationConfig(
         penalty_alpha=0.6,
@@ -103,4 +108,4 @@ def handle_post_request():
     return jsonify({"error": str(e)}), 500
 
 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=7860)
+    app.run(host='0.0.0.0', port=7860)
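For readers unfamiliar with `accelerate`, here is a minimal, self-contained sketch of the device-placement pattern this commit adopts. The model id and prompt below are placeholders, not taken from the Space. Note also that since the commit loads the model with `device_map="auto"` and a quantization config, accelerate's big-model dispatch has typically already placed the weights, so `accelerator.prepare(model)` there is mainly a belt-and-braces step; the sketch shows the pattern on a plain single-device model, where it does the actual move.

```python
# Minimal sketch of Accelerator-based device placement (placeholder model id).
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

accelerator = Accelerator()  # resolves to CUDA, MPS, or CPU automatically

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # placeholder model id
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = accelerator.prepare(model)                    # moves the model to accelerator.device

# Inputs must live on the same device as the model before generation.
inputs = tokenizer("Hello, world", return_tensors="pt").to(accelerator.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```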
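Once the Space is running, the endpoint could be exercised roughly as below. The JSON field names are assumptions inferred from the `generate_response(user_input, model_id)` signature; the route's actual request schema may differ.

```python
# Hypothetical client for the /send_message endpoint; field names are assumed.
import requests

resp = requests.post(
    "http://localhost:7860/send_message",     # host/port from app.run(...)
    json={
        "user_input": "Hello, how are you?",  # assumed field name
        "model_id": "some-org/some-model",    # assumed field name, placeholder id
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json())
```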