YALCINKAYA commited on
Commit
e831671
·
verified ·
1 Parent(s): b59a0e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -22,11 +22,22 @@ app = Flask(__name__)
22
  CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
23
 
24
  # Load zero-shot classification pipeline
25
- classifier = pipeline("zero-shot-classification")
26
 
27
  # Load Sentence-BERT model
28
  bertmodel = SentenceTransformer('all-MiniLM-L6-v2') # Lightweight, efficient model; choose larger if needed
29
-
 
 
 
 
 
 
 
 
 
 
 
30
  # Global variables for model and tokenizer
31
  model = None
32
  tokenizer = None
 
22
  CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
23
 
24
  # Load zero-shot classification pipeline
25
+ #classifier = pipeline("zero-shot-classification")
26
 
27
  # Load Sentence-BERT model
28
  bertmodel = SentenceTransformer('all-MiniLM-L6-v2') # Lightweight, efficient model; choose larger if needed
29
+
30
+ # Load model with accelerator
31
+ classifier = pipeline(
32
+ "zero-shot-classification",
33
+ model="facebook/bart-large-mnli",
34
+ revision="d7645e1",
35
+ device=accelerator.device # Ensures correct device placement
36
+ )
37
+
38
+ # Move model to correct device
39
+ classifier.model = accelerator.prepare(classifier.model)
40
+
41
  # Global variables for model and tokenizer
42
  model = None
43
  tokenizer = None