mike23415 committed on
Commit
fa2a9d3
·
verified ·
1 Parent(s): ff6f1af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -17
app.py CHANGED
@@ -1,11 +1,16 @@
1
  import os
2
  import time
3
  import torch
 
4
  from flask import Flask, request, jsonify
5
  from flask_cors import CORS
6
- from transformers import AutoModelForCausalLM, AutoTokenizer
7
  import gradio as gr
8
 
 
 
 
 
9
  # Global variables
10
  MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
11
  MAX_LENGTH = 2048
@@ -27,21 +32,29 @@ def load_model_and_tokenizer():
27
 
28
  print(f"Loading model: {MODEL_ID}")
29
 
30
- # Load tokenizer
31
- tokenizer = AutoTokenizer.from_pretrained(
32
- MODEL_ID,
33
- use_fast=True,
34
- )
35
-
36
- # Load model with optimizations for limited resources
37
- model = AutoModelForCausalLM.from_pretrained(
38
- MODEL_ID,
39
- device_map="auto",
40
- torch_dtype=torch.bfloat16,
41
- load_in_4bit=True,
42
- )
43
-
44
- print("Model and tokenizer loaded successfully!")
 
 
 
 
 
 
 
 
45
 
46
  # Initialize Flask app
47
  app = Flask(__name__)
@@ -205,4 +218,8 @@ if __name__ == "__main__":
205
 
206
  # Create and launch Gradio interface
207
  demo = create_ui()
208
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
 
 
 
 
1
  import os
2
  import time
3
  import torch
4
+ import warnings
5
  from flask import Flask, request, jsonify
6
  from flask_cors import CORS
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, logging
8
  import gradio as gr
9
 
10
+ # Suppress warnings
11
+ warnings.filterwarnings("ignore")
12
+ logging.set_verbosity_error()
13
+
14
  # Global variables
15
  MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
16
  MAX_LENGTH = 2048
 
32
 
33
  print(f"Loading model: {MODEL_ID}")
34
 
35
+ try:
36
+ # Load tokenizer
37
+ tokenizer = AutoTokenizer.from_pretrained(
38
+ MODEL_ID,
39
+ use_fast=True,
40
+ trust_remote_code=True # Added to trust remote code
41
+ )
42
+
43
+ # Load model with optimizations for limited resources
44
+ model = AutoModelForCausalLM.from_pretrained(
45
+ MODEL_ID,
46
+ device_map="auto",
47
+ torch_dtype=torch.bfloat16,
48
+ load_in_4bit=True,
49
+ trust_remote_code=True # Added to trust remote code
50
+ )
51
+
52
+ print("Model and tokenizer loaded successfully!")
53
+ except Exception as e:
54
+ import traceback
55
+ print(f"Error loading model: {str(e)}")
56
+ print(traceback.format_exc())
57
+ raise
58
 
59
  # Initialize Flask app
60
  app = Flask(__name__)
 
218
 
219
  # Create and launch Gradio interface
220
  demo = create_ui()
221
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
222
+
223
+ # Flask won't reach here when Gradio is running
224
+ # If you want to run Flask separately:
225
+ # app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))