import os
from pathlib import Path

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM

# Create the cache directory if it does not exist
cache_dir = Path(os.getenv('TRANSFORMERS_CACHE', '/app/cache'))
cache_dir.mkdir(parents=True, exist_ok=True)

app = Flask(__name__)
CORS(app)

# Model configuration
MODEL_NAME = "deepseek-ai/deepseek-r1-6b-chat"
MAX_NEW_TOKENS = 256
DEVICE = "cpu"

# Initialize tokenizer and model; keep the app serving /health even if loading fails
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=str(cache_dir)
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=str(cache_dir),
        device_map="auto",  # requires the accelerate package; resolves to CPU on a GPU-less host
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True
    )
    print("Model loaded successfully!")
except Exception as e:
    print(f"Model loading failed: {e}")
    tokenizer = None
    model = None


def generate_response(prompt):
    try:
        # Move inputs to wherever device_map placed the model weights
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating response: {e}"


@app.route('/chat', methods=['POST'])
def chat():
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    data = request.get_json()
    if not data or 'prompt' not in data:
        return jsonify({"error": "No prompt provided"}), 400

    prompt = data['prompt'].strip()
    if not prompt:
        return jsonify({"error": "Empty prompt"}), 400

    try:
        response = generate_response(prompt)
        # Intended to clean up extra text after the final answer, but the
        # delimiter for split() is missing in the source (str.split("") raises
        # ValueError), so only surrounding whitespace is trimmed here.
        response = response.strip()
        return jsonify({"response": response})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route('/health', methods=['GET'])
def health_check():
    status = {
        "model_loaded": model is not None,
        "device": DEVICE,
        "cache_dir": str(cache_dir),
        "memory_usage": (f"{torch.cuda.memory_allocated() / 1024 ** 2:.2f}MB"
                         if torch.cuda.is_available() else "CPU")
    }
    return jsonify(status)


@app.route('/')
def home():
    return jsonify({
        "service": "DeepSeek Chat API",
        "endpoints": {
            "POST /chat": "Process chat prompts",
            "GET /health": "Service health check"
        },
        "config": {
            "max_tokens": MAX_NEW_TOKENS,
            "model": MODEL_NAME
        }
    })


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
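
# ----------------------------------------------------------------------
# Usage sketch, kept as a comment so the module stays importable. It
# assumes the server above is running locally on port 5000 and that the
# `requests` package is installed (an assumption; it is not imported by
# this service). The payload and endpoint names mirror the handlers
# defined in this file:
#
#     import requests
#
#     r = requests.post(
#         "http://localhost:5000/chat",
#         json={"prompt": "Hello, who are you?"},
#         timeout=600,  # CPU generation of up to 256 tokens can be slow
#     )
#     print(r.json()["response"])
#
#     # Check model/device status before sending prompts
#     print(requests.get("http://localhost:5000/health").json())
# ----------------------------------------------------------------------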