mike23415 committed on
Commit
6f93dce
·
verified ·
1 Parent(s): 4c35444

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import torch
5
+
6
# Create the Flask application object and enable CORS handling on it.
app = Flask(import_name=__name__)
CORS(app)
8
+
9
# Model configuration: the checkpoint identifier passed to from_pretrained()
# and the cap on how many tokens a single generation may add.
MODEL_NAME = "deepseek-ai/deepseek-r1-6b-chat"
MAX_NEW_TOKENS = 512

# Use the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
13
+
14
# Initialize model and tokenizer.
# Both names are bound before the attempt so they always exist at module
# level, even if loading fails part-way through (for example when the
# tokenizer itself cannot be loaded).
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        # bfloat16 when running on CUDA, full float32 precision on the CPU.
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
    )
    print("Model loaded successfully!")
except Exception as e:
    # Deliberately broad: a load failure must not prevent the module from
    # importing. `model` is left as None so the rest of the file can detect
    # the failure; the tokenizer is reset too so the pair stays consistent.
    print(f"Model loading failed: {e}")
    tokenizer = None
    model = None
26
+
27
def generate_response(prompt):
    """Generate a sampled completion for *prompt* and return it as text.

    Args:
        prompt: Text that is tokenized and passed to the model.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # The model is loaded with device_map="auto", so its placement is decided
    # at load time. Move the inputs to the device the model reports rather
    # than to the module-level DEVICE string, so the two cannot disagree.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # The EOS token id doubles as the padding id for generation.
        pad_token_id=tokenizer.eos_token_id,
    )
    # NOTE(review): for decoder-only models generate() typically returns the
    # prompt tokens followed by the continuation, so the decoded text likely
    # starts with the prompt itself -- confirm whether callers expect that.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
38
+
39
@app.route('/chat', methods=['POST'])
def chat():
    """Handle POST /chat: generate a reply for the JSON field ``prompt``.

    Returns:
        A JSON body and status code:
        * ``{"response": <text>}`` on success (default 200),
        * ``{"error": ...}`` with 400 when the body is not a JSON object or
          the prompt is missing, empty, or not a string,
        * ``{"error": ...}`` with 500 when the model is not loaded or
          generation raises.
    """
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True returns None for a missing, non-JSON or malformed body
    # instead of raising, so every bad request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Covers None, lists, strings, numbers: anything without .get().
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get("prompt", "")
    if not isinstance(prompt, str):
        # A non-string prompt is a client error, not a server failure.
        return jsonify({"error": "Prompt must be a string"}), 400
    if not prompt:
        return jsonify({"error": "No prompt provided"}), 400

    try:
        response = generate_response(prompt)
    except Exception as e:
        # Endpoint boundary: report any generation failure to the client.
        return jsonify({"error": str(e)}), 500

    return jsonify({"response": response})
56
+
57
@app.route('/health', methods=['GET'])
def health_check():
    """Handle GET /health: report whether the model is available.

    Returns:
        JSON ``{"status": "ready"}`` when ``model`` is truthy, otherwise
        ``{"status": "unavailable"}``.
    """
    if model:
        state = "ready"
    else:
        state = "unavailable"
    return jsonify({"status": state})
61
+
62
# Start Flask's built-in server only when this file is executed directly.
if __name__ == "__main__":
    # Bind to all interfaces on port 5000.
    app.run(port=5000, host="0.0.0.0")