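"""Flask API that serves a DeepSeek chat model over HTTP.

Endpoints:
    POST /chat    generate a completion for a JSON prompt
    GET  /health  report model/device status
    GET  /        list endpoints and configuration

Run locally (listens on 0.0.0.0:5000):
    python app.py
"""
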
import os
from pathlib import Path
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Create the model cache directory if it doesn't exist
cache_dir = Path(os.getenv('TRANSFORMERS_CACHE', '/app/cache'))
cache_dir.mkdir(parents=True, exist_ok=True)

app = Flask(__name__)
CORS(app)

# Model configuration (MODEL_NAME is a Hugging Face model id; adjust it if the
# repo is not available in your environment)
MODEL_NAME = "deepseek-ai/deepseek-r1-6b-chat"
MAX_NEW_TOKENS = 256
DEVICE = "cpu"

# Load the tokenizer and model once at startup; if loading fails, the service
# keeps running and /chat responds with a 500 error instead.
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=str(cache_dir)
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=str(cache_dir),
        device_map="auto",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True
    )
    print("Model loaded successfully!")
except Exception as e:
    print(f"Model loading failed: {str(e)}")
    tokenizer = None
    model = None

def generate_response(prompt):
    try:
        # Keep inputs on the same device the model was placed on by device_map
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Sample with moderate randomness (temperature/top_p) rather than greedy decoding
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        return f"Error generating response: {str(e)}"
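
# POST /chat expects a JSON body of the form {"prompt": "..."} and returns
# {"response": "..."}. Example request (assuming the default port 5000):
#   curl -X POST http://localhost:5000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, who are you?"}'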

@app.route('/chat', methods=['POST'])
def chat():
    if not model:
        return jsonify({"error": "Model not loaded"}), 500
    
    data = request.get_json()
    if not data or 'prompt' not in data:
        return jsonify({"error": "No prompt provided"}), 400
    
    prompt = data['prompt'].strip()
    if not prompt:
        return jsonify({"error": "Empty prompt"}), 400
    
    try:
        response = generate_response(prompt)
        # Defensive cleanup: drop anything after an end-of-sequence marker in
        # case special tokens survive decoding, then trim whitespace
        response = response.split("</s>")[0].strip()
        return jsonify({"response": response})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
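
# GET /health returns a small status document, e.g.:
#   {"model_loaded": true, "device": "cpu", "cache_dir": "/app/cache", "memory_usage": "CPU"}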

@app.route('/health', methods=['GET'])
def health_check():
    status = {
        "model_loaded": bool(model),
        "device": DEVICE,
        "cache_dir": str(cache_dir),
        # Report GPU memory when CUDA is available; this deployment targets CPU
        "memory_usage": f"{torch.cuda.memory_allocated()/1024**2:.2f}MB" if torch.cuda.is_available() else "CPU"
    }
    return jsonify(status)

@app.route('/')
def home():
    return jsonify({
        "service": "DeepSeek Chat API",
        "endpoints": {
            "POST /chat": "Process chat prompts",
            "GET /health": "Service health check"
        },
        "config": {
            "max_tokens": MAX_NEW_TOKENS,
            "model": MODEL_NAME
        }
    })
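
# Note: app.run() starts Flask's built-in development server; for production
# use, a WSGI server (e.g. `gunicorn -w 1 -b 0.0.0.0:5000 app:app`) is the
# usual choice.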

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)