mike23415 committed
Commit f1fd41e · verified · 1 Parent(s): 3586fbb

Update app.py

Files changed (1)
  1. app.py +57 -21
app.py CHANGED
@@ -1,23 +1,34 @@
+import os
+from pathlib import Path
 from flask import Flask, request, jsonify
 from flask_cors import CORS
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
+# Create cache directory if not exists
+cache_dir = Path(os.getenv('TRANSFORMERS_CACHE', '/app/cache'))
+cache_dir.mkdir(parents=True, exist_ok=True)
+
 app = Flask(__name__)
 CORS(app)
 
 # Model configuration
 MODEL_NAME = "deepseek-ai/deepseek-r1-6b-chat"
-MAX_NEW_TOKENS = 512
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+MAX_NEW_TOKENS = 256  # Reduced for free tier limits
+DEVICE = "cpu"  # Force CPU for Hugging Face Spaces
 
-# Initialize model and tokenizer
+# Initialize model
 try:
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_NAME,
+        cache_dir=str(cache_dir)
+    )
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
+        cache_dir=str(cache_dir),
         device_map="auto",
-        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True
     )
     print("Model loaded successfully!")
 except Exception as e:
@@ -25,39 +36,64 @@ except Exception as e:
     model = None
 
 def generate_response(prompt):
-    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=MAX_NEW_TOKENS,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
-        pad_token_id=tokenizer.eos_token_id
-    )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+    try:
+        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=MAX_NEW_TOKENS,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )
+        return tokenizer.decode(outputs[0], skip_special_tokens=True)
+    except Exception as e:
+        return f"Error generating response: {str(e)}"
 
 @app.route('/chat', methods=['POST'])
 def chat():
     if not model:
         return jsonify({"error": "Model not loaded"}), 500
 
-    data = request.json
-    prompt = data.get("prompt", "")
+    data = request.get_json()
+    if not data or 'prompt' not in data:
+        return jsonify({"error": "No prompt provided"}), 400
 
+    prompt = data['prompt'].strip()
     if not prompt:
-        return jsonify({"error": "No prompt provided"}), 400
+        return jsonify({"error": "Empty prompt"}), 400
 
     try:
         response = generate_response(prompt)
+        # Clean up extra text after the final answer
+        response = response.split("</s>")[0].strip()
         return jsonify({"response": response})
-
     except Exception as e:
         return jsonify({"error": str(e)}), 500
 
 @app.route('/health', methods=['GET'])
 def health_check():
-    status = "ready" if model else "unavailable"
-    return jsonify({"status": status})
+    status = {
+        "model_loaded": bool(model),
+        "device": DEVICE,
+        "cache_dir": str(cache_dir),
+        "memory_usage": f"{torch.cuda.memory_allocated()/1024**2:.2f}MB" if torch.cuda.is_available() else "CPU"
+    }
+    return jsonify(status)
+
+@app.route('/')
+def home():
+    return jsonify({
+        "service": "DeepSeek Chat API",
+        "endpoints": {
+            "POST /chat": "Process chat prompts",
+            "GET /health": "Service health check"
+        },
+        "config": {
+            "max_tokens": MAX_NEW_TOKENS,
+            "model": MODEL_NAME
+        }
+    })
 
 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=5000)
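
A quick way to exercise the routes touched by this commit is a short client script. This is a minimal sketch, not part of the commit: it assumes the app is reachable at http://localhost:5000 (adjust for a deployed Space URL) and that the requests package is installed.

import requests

BASE = "http://localhost:5000"  # assumed address; adjust for your deployment

# Root route added in this commit: service name, endpoints, and config
print(requests.get(f"{BASE}/", timeout=30).json())

# Health route now returns a dict: model_loaded, device, cache_dir, memory_usage
print(requests.get(f"{BASE}/health", timeout=30).json())

# Chat route: returns 400 if "prompt" is missing or empty, 500 if the model failed to load
resp = requests.post(
    f"{BASE}/chat",
    json={"prompt": "Hello!"},
    timeout=600,  # CPU generation on a large model can take minutes
)
print(resp.status_code, resp.json())

Since the commit forces DEVICE to "cpu" and the model has billions of parameters, the /chat call can take a long time per request; a generous client timeout is the safer choice.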