from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import os
import torch
import time
import logging
import threading
import queue
import json
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Fix caching issue on Hugging Face Spaces
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Global model variables
tokenizer = None
model = None


# Initialize models once on startup
def initialize_models():
    global tokenizer, model

    try:
        logger.info("Loading language model...")

        # You can change the model here if needed
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # Good balance of quality and speed for CPU

        # Load tokenizer with caching
        logger.info(f"Loading tokenizer: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True  # Use the fast tokenizer when available
        )

        # Load model with optimizations for CPU
        logger.info(f"Loading model: {model_name}")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use float16 for lower memory (switch to float32 if your torch build lacks fp16 CPU kernels)
            device_map="cpu",           # Explicitly set to CPU
            low_cpu_mem_usage=True,     # Optimize memory usage while loading
            offload_folder="offload"    # Use disk offloading if needed
        )

        # Handle padding tokens
        if tokenizer.pad_token is None:
            logger.info("Setting pad token to EOS token")
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id

        # Set up default generation behaviour
        model.generation_config.do_sample = True   # Enable sampling
        model.generation_config.temperature = 0.7  # Default temperature
        model.generation_config.top_p = 0.9        # Default top_p

        logger.info("Models initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise


# Custom streamer for token-by-token generation. It pushes newly generated text
# onto a queue and implements the put()/end() interface expected by
# model.generate(streamer=...).
class TextStreamer:
    def __init__(self, tokenizer, result_queue):
        self.tokenizer = tokenizer
        self.queue = result_queue
        self.current_tokens = []
        self.emitted_len = 0
        self.prompt_skipped = False

    def put(self, token_ids):
        # The first call from generate() contains the prompt tokens; skip them
        # so that only newly generated text is streamed to the client.
        if not self.prompt_skipped:
            self.prompt_skipped = True
            return

        self.current_tokens.extend(token_ids.reshape(-1).tolist())
        text = self.tokenizer.decode(self.current_tokens, skip_special_tokens=True)

        # Push only the newly decoded part of the text
        new_text = text[self.emitted_len:]
        if new_text:
            self.queue.put(new_text)
            self.emitted_len = len(text)

    def end(self):
        pass


# Function to simulate "thinking" process
def thinking_process(message, result_queue):
    """
    Simulate a thinking process and put the results in the queue.
    It includes an explicit thinking stage followed by a generation stage.
    """
    try:
        # Explicit thinking stage
        logger.info(f"Thinking about: '{message}'")

        # Pause to simulate deeper thinking (helps with more complex queries)
        time.sleep(1)

        # Create a thoughtful prompt with a system message and thinking instructions
        prompt = f"""<|im_start|>system
You are a helpful, friendly, and thoughtful AI assistant. Let's approach the user's request step by step:
1. First, understand what the user is really asking
2. Consider the key aspects we need to address
3. Think about the best way to structure the response
4. Provide clear, accurate information in a conversational tone
Always think carefully before responding, consider different angles, and provide thoughtful, detailed answers.
<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""

        # Tokenize the prompt
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to('cpu') for k, v in inputs.items()}

        # Generate the answer with streaming
        streamer = TextStreamer(tokenizer, result_queue)

        # Simulate thinking first by sending an initial status message
        result_queue.put("Let me think about this...")
        time.sleep(0.5)

        # Generate the response. Temperature 0.7 gives more thoughtful outputs,
        # and top_p (nucleus sampling) avoids repetitive or generic responses.
        try:
            model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                streamer=streamer,
                num_beams=1,             # Streaming is not compatible with beam search
                no_repeat_ngram_size=3,
                repetition_penalty=1.2   # Discourages token repetition
            )
        except Exception as e:
            logger.error(f"Model generation error: {str(e)}")
            result_queue.put("\n\nI apologize, but I encountered an error while processing your request.")

        # Signal that generation is complete
        result_queue.put(None)

    except Exception as e:
        logger.error(f"Error in thinking process: {str(e)}")
        result_queue.put(f"I apologize, but I encountered an error while processing your request: {str(e)}")
        # Signal that generation is complete
        result_queue.put(None)


# API route for home page
@app.route('/')
def home():
    return jsonify({"message": "AI Chat API is running!"})


# API route for streaming chat responses
@app.route('/chat', methods=['POST', 'GET'])
def chat():
    # Handle both POST JSON and GET query parameters for flexibility
    if request.method == 'POST':
        try:
            data = request.get_json()
            message = data.get("message", "")
        except Exception:
            # If JSON parsing fails, try form data
            message = request.form.get("message", "")
    else:  # GET
        message = request.args.get("message", "")

    if not message:
        return jsonify({"error": "Message is required"}), 400

    try:
        def generate():
            # Configure the SSE stream and show a thinking indicator
            yield "retry: 1000\n"
            yield "event: message\n"
            yield "data: [Thinking...]\n\n"

            # Create a queue for communication between threads
            result_queue = queue.Queue()

            # Start thinking in a separate thread
            thread = threading.Thread(target=thinking_process, args=(message, result_queue))
            thread.daemon = True  # Make the thread die when the main thread exits
            thread.start()

            # Stream results as they become available
            while True:
                try:
                    chunk = result_queue.get(block=True, timeout=30)  # 30 second timeout

                    if chunk is None:  # End of generation
                        break

                    # The streamer puts only newly generated text on the queue,
                    # so each chunk can be forwarded to the client as-is.
                    if isinstance(chunk, str) and chunk:
                        yield f"data: {json.dumps(chunk)}\n\n"
                        time.sleep(0.01)  # Small delay for a more natural typing effect

                except queue.Empty:
                    # Timeout occurred
                    yield "data: [Generation timeout. The model is taking too long to respond.]\n\n"
                    break

            yield "data: [DONE]\n\n"

        return Response(
            stream_with_context(generate()),
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'X-Accel-Buffering': 'no'  # Disable buffering for Nginx
            }
        )

    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500


# Simple API for non-streaming chat (fallback)
@app.route('/chat-simple', methods=['POST'])
def chat_simple():
    data = request.get_json(silent=True) or {}
    message = data.get("message", "")

    if not message:
        return jsonify({"error": "Message is required"}), 400

    try:
        # Create prompt with system message
        prompt = f"""<|im_start|>system
You are a helpful, friendly, and thoughtful AI assistant. Think carefully and provide informative, detailed responses.
<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""

        # Tokenize the prompt
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to('cpu') for k, v in inputs.items()}

        # Generate the answer
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=1,
            no_repeat_ngram_size=3
        )

        # Decode only the newly generated tokens (everything after the prompt)
        generated_tokens = output[0][inputs["input_ids"].shape[1]:]
        answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        return jsonify({"response": answer})

    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500


if __name__ == "__main__":
    try:
        # Initialize models at startup
        initialize_models()

        logger.info("Starting Flask application")
        app.run(host="0.0.0.0", port=7860)
    except Exception as e:
        logger.critical(f"Failed to start application: {str(e)}")
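
# ---------------------------------------------------------------------------
# Example client (an illustrative sketch, not part of the server): one way the
# /chat SSE endpoint above could be consumed. It assumes the server is running
# locally on the default port used in app.run() and that the third-party
# `requests` package is installed; adjust the URL for your deployment.
#
#   import json
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/chat",
#       json={"message": "Explain what an SSE stream is."},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines(decode_unicode=True):
#           if not line or not line.startswith("data: "):
#               continue
#           payload = line[len("data: "):]
#           if payload == "[DONE]":
#               break
#           try:
#               # Generated chunks are JSON-encoded strings
#               print(json.loads(payload), end="", flush=True)
#           except json.JSONDecodeError:
#               # Status markers such as [Thinking...] are sent as plain text
#               print(payload, end="", flush=True)
# ---------------------------------------------------------------------------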