YALCINKAYA committed on
Commit 284c0f7 · verified · 1 Parent(s): 34139ad

Update app.py

Files changed (1)
  1. app.py +36 -35
app.py CHANGED
@@ -1,21 +1,23 @@
 import os
 import torch
 from flask import Flask, jsonify, request
-from flask_cors import CORS
-from transformers import GPTNeoForCausalLM, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, GenerationConfig
+from flask_cors import CORS
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
+import re
+
 # Set the HF_HOME environment variable to a writable directory
-os.environ["HF_HOME"] = "/workspace/huggingface_cache" # Change this to a writable path in your space
+os.environ["HF_HOME"] = "/workspace/huggingface_cache"
 
 app = Flask(__name__)
 
 # Enable CORS for specific origins
-CORS(app, resources={r"api/predict/*": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
-
+CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
+
 # Global variables for model and tokenizer
 model = None
 tokenizer = None
 
-def get_model_and_tokenizer(model_id):
+def get_model_and_tokenizer(model_id: str):
     global model, tokenizer
     if model is None or tokenizer is None:
         try:
@@ -23,33 +25,32 @@ def get_model_and_tokenizer(model_id):
             tokenizer = AutoTokenizer.from_pretrained(model_id)
             tokenizer.pad_token = tokenizer.eos_token
 
-            print(f"Loading model for model_id: {model_id} on {device}")
-
-            bnb_config = BitsAndBytesConfig(
-                load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="float16", bnb_4bit_use_double_quant=True
+            print(f"Loading model for model_id: {model_id}")
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True
             )
 
             model = AutoModelForCausalLM.from_pretrained(
                 model_id, quantization_config=bnb_config, device_map="auto"
             )
 
-            model.config.use_cache=False
-            model.config.pretraining_tp=1
+            model.config.use_cache = False
+            model.config.pretraining_tp = 1
+            model.config.pad_token_id = tokenizer.eos_token_id  # Fix padding issue
 
-
         except Exception as e:
             print(f"Error loading model: {e}")
-            raise e # Raise the error to be caught in the POST request
-    else:
-        print(f"Model and tokenizer for {model_id} are already loaded.")
+            raise e
 
 def generate_response(user_input, model_id):
     # Ensure model and tokenizer are loaded
     get_model_and_tokenizer(model_id)
 
-    prompt = user_input
-
-    generation_config = GenerationConfig(
+    prompt = user_input
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    generation_config = GenerationConfig(
         penalty_alpha=0.6,
         do_sample=True,
         top_p=0.2,
@@ -57,18 +58,18 @@ def generate_response(user_input, model_id):
         temperature=0.3,
         repetition_penalty=1.2,
         max_new_tokens=60,
-        pad_token_id=tokenizer.eos_token_id,
-        stop_sequences=["User:", "Assistant:", "\n"],
+        pad_token_id=tokenizer.eos_token_id
     )
 
-
-    inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
-
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    model.to(device)
+
     outputs = model.generate(**inputs, generation_config=generation_config)
-    response = (tokenizer.decode(outputs[0], skip_special_tokens=True))
-
-    cleaned_response = response.replace("User:", "").replace("Assistant:", "").strip()
-    return cleaned_response.strip().split("\n")[0] # Keep only the first line of response
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Clean up response
+    cleaned_response = re.sub(r"(User:|Assistant:)", "", response).strip()
+    return cleaned_response.split("\n")[0]  # Keep only the first line of response
 
 @app.route("/", methods=["GET"])
 def handle_get_request():
@@ -81,21 +82,21 @@ def handle_post_request():
     if data is None:
         return jsonify({"error": "No JSON data provided"}), 400
 
-    message = data.get("inputs", "No message provided.")
-    model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin") # Default model if not provided
+    message = data.get("inputs", "No message provided.")
+    model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")
 
     try:
-        print(f"Loading")
-        # Generate a response from the model
+        print(f"Processing request")
         model_response = generate_response(message, model_id)
         return jsonify({
-            "received_message": model_response,
-            "model_id": model_id,
+            "received_message": model_response,
+            "model_id": model_id,
             "status": "POST request successful!"
         })
     except Exception as e:
+        error_message = str(e) if app.debug else "An error occurred while processing your request."
         print(f"Error handling POST request: {e}")
-        return jsonify({"error": "An error occurred while processing your request."}), 500
+        return jsonify({"error": error_message}), 500
 
 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=7860)
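
For reference, a minimal client sketch against the updated handler is shown below. The POST route decorator itself is outside the visible hunks, so the "/send_message" path is an assumption taken from the new CORS resource pattern; the "inputs"/"model_id" request keys, the response fields, and the port match the handler and app.run() call shown in the diff.

# Hypothetical client for the updated endpoint. The "/send_message" path is
# assumed from the CORS configuration; the JSON keys mirror handle_post_request.
import requests

payload = {
    "inputs": "Hello, how are you?",
    # Optional: the handler falls back to this model_id when the key is missing.
    "model_id": "YALCINKAYA/FinetunedByYalcin",
}

resp = requests.post("http://localhost:7860/send_message", json=payload, timeout=120)
resp.raise_for_status()

data = resp.json()
print(data["received_message"])  # first line of the cleaned model reply
print(data["status"])            # "POST request successful!"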