import os
import time
import json
import numpy as np
from pathlib import Path
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
import torch
import gc  # For garbage collection

# Create cache directory if not exists
cache_dir = Path(os.getenv('TRANSFORMERS_CACHE', '/app/cache'))
cache_dir.mkdir(parents=True, exist_ok=True)

app = Flask(__name__)
CORS(app)  # Allow cross-origin requests

# Model configuration
MODEL_NAME = "deepseek-ai/deepseek-r1-6b-chat"
MAX_NEW_TOKENS = 256
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# Initialize model variables
tokenizer = None
model = None

def load_model():
    """Load model on first request to save memory at startup"""
    global tokenizer, model
    
    if tokenizer is not None and model is not None:
        return True
    
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        print(f"Loading model {MODEL_NAME}...")
        print(f"Using device: {DEVICE}")
        print(f"Cache directory: {cache_dir}")
        
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            cache_dir=str(cache_dir)
        )
        
        # Load model with low memory settings
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=str(cache_dir),
            device_map="auto" if DEVICE == "cuda" else None,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            low_cpu_mem_usage=True)
        
        print("βœ… Model loaded successfully!")
        return True
    except Exception as e:
        print(f"❌ Model loading failed: {str(e)}")
        return False

def stream_generator(prompt):
    """Generator function for streaming response with thinking steps"""
    # Ensure model is loaded
    if not load_model():
        yield json.dumps({"type": "error", "content": "Model not loaded"}) + '\n'
        return
    
    # Thinking phases
    thinking_steps = [
        "🔍 Analyzing your question...",
        "🧠 Accessing knowledge base...",
        "💡 Formulating response...",
        "📚 Verifying information..."
    ]
    
    # Stream thinking steps
    for step in thinking_steps:
        yield json.dumps({"type": "thinking", "content": step}) + '\n'
        time.sleep(0.8)  # Reduced timing for faster response
    
    # Prepare streaming generation
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        if DEVICE == "cuda":
            inputs = inputs.to("cuda")
        
        # Generate the full response in one pass, then stream it back to the
        # client in small chunks below (pseudo-streaming, not token-by-token)
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=False)
        
        # Get output sequence
        output_ids = generated_ids.sequences[0][len(inputs.input_ids[0]):]
        
        # Stream in chunks for smoother experience
        full_output = ""
        chunk_size = 3  # Number of tokens per chunk
        for i in range(0, len(output_ids), chunk_size):
            chunk_ids = output_ids[i:i+chunk_size]
            chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True)
            full_output += chunk_text
            
            yield json.dumps({
                "type": "answer",
                "content": chunk_text
            }) + '\n'
            
            # Small delay for smoother streaming
            time.sleep(0.05)
            
    except Exception as e:
        import traceback
        error_details = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_details)
        yield json.dumps({
            "type": "error",
            "content": f"Generation error: {str(e)}"
        }) + '\n'
    
    # Signal completion
    yield json.dumps({"type": "complete"}) + '\n'
    
    # Clean up memory
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

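# Each line yielded by stream_generator() is a standalone JSON object, e.g.
#   {"type": "thinking", "content": "..."}  -> progress message
#   {"type": "answer", "content": "..."}    -> a chunk of the reply text
#   {"type": "error", "content": "..."}     -> failure report
#   {"type": "complete"}                    -> end of stream
# A hypothetical client could consume the stream roughly like this (a sketch
# only, assuming the `requests` package and a server at http://localhost:5000):
#
#   import requests, json
#   with requests.post("http://localhost:5000/stream_chat",
#                      json={"prompt": "Hello"}, stream=True) as resp:
#       for line in resp.iter_lines():
#           if not line:
#               continue
#           event = json.loads(line)
#           if event["type"] == "answer":
#               print(event["content"], end="", flush=True)
#           elif event["type"] == "complete":
#               break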
@app.route('/stream_chat', methods=['POST'])
def stream_chat():
    # Tolerate a missing or non-JSON body instead of failing on data.get()
    data = request.get_json(silent=True) or {}
    prompt = data.get('prompt', '').strip()
    
    if not prompt:
        return jsonify({"error": "Empty prompt"}), 400
    
    return Response(
        stream_generator(prompt),
        # Note: the payload is newline-delimited JSON objects, not strict SSE
        # "data:" frames, even though an event-stream mimetype is advertised
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',  # Prevent Nginx buffering
            'Connection': 'keep-alive'
        }
    )

@app.route('/chat', methods=['POST'])
def chat():
    # Ensure model is loaded
    if not load_model():
        return jsonify({"error": "Model failed to load"}), 500
    
    # Tolerate a missing or non-JSON body instead of failing on data.get()
    data = request.get_json(silent=True) or {}
    prompt = data.get('prompt', '').strip()
    
    if not prompt:
        return jsonify({"error": "Empty prompt"}), 400
    
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        if DEVICE == "cuda":
            inputs = inputs.to("cuda")
            
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id)
        
        response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
        
        # Clean up memory
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        
        return jsonify({"response": response})
    
    except Exception as e:
        import traceback
        error_details = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_details)
        return jsonify({"error": str(e)}), 500

@app.route('/health', methods=['GET'])
def health_check():
    model_loaded = tokenizer is not None and model is not None
    
    try:
        # Check if we need to load the model
        if not model_loaded and request.args.get('load') == 'true':
            model_loaded = load_model()
    except Exception as e:
        print(f"Health check error: {str(e)}")
    
    status = {
        "status": "ok" if model_loaded else "waiting",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "cache_dir": str(cache_dir),
        "max_tokens": MAX_NEW_TOKENS,
        "memory_usage": f"{torch.cuda.memory_allocated()/1024**2:.2f}MB" 
            if torch.cuda.is_available() else "CPU"
    }
    return jsonify(status)

@app.route('/')
def home():
    return jsonify({
        "service": "DeepSeek Chat API",
        "status": "online",
        "endpoints": {
            "POST /chat": "Single-response chat",
            "POST /stream_chat": "Streaming chat with thinking steps",
            "GET /health": "Service health check"
        },
        "config": {
            "model": MODEL_NAME,
            "max_tokens": MAX_NEW_TOKENS,
            "cache_location": str(cache_dir)
        }
    })

if __name__ == '__main__':
    # Load model at startup - only if explicitly requested
    if os.getenv('PRELOAD_MODEL', 'false').lower() == 'true':
        load_model()
    
    port = int(os.environ.get("PORT", 5000))
    app.run(host='0.0.0.0', port=port)
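
# Example interactive use once the server is running (a rough sketch, assuming
# the `requests` package and the default local address http://localhost:5000;
# the routes and the "prompt" field come from the handlers above):
#
#   import requests
#   print(requests.get("http://localhost:5000/health").json())
#   reply = requests.post("http://localhost:5000/chat",
#                         json={"prompt": "Hello"}).json()
#   print(reply.get("response", reply))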