from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import os
import torch
import time
import logging
import threading
import queue
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fix caching issue on Hugging Face Spaces
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Global model variables
tokenizer = None
model = None


# Initialize models once on startup
def initialize_models():
    global tokenizer, model
    try:
        logger.info("Loading language model...")
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use float16 for lower memory
            device_map="cpu",           # Explicitly set to CPU
            low_cpu_mem_usage=True      # Optimize memory usage while loading
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id
        logger.info("Models initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise


# Function to simulate the "thinking" process
def thinking_process(message, result_queue):
    """
    Simulate a thinking process: generate a response to `message` and put
    partial results into `result_queue`.
    """
    try:
        logger.info(f"Thinking about: '{message}'")

        # Create prompt with system message
        prompt = f"""<|im_start|>system
You are a helpful, friendly, and thoughtful AI assistant. Think carefully and provide informative, detailed responses.
<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
        # Handle inputs
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to('cpu') for k, v in inputs.items()}

        # Generate answer with streaming
        streamer = TextStreamer(tokenizer, result_queue)

        # Generate response
        model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            streamer=streamer,
            num_beams=1,
            no_repeat_ngram_size=3
        )

        # Signal generation is complete
        result_queue.put(None)
    except Exception as e:
        logger.error(f"Error in thinking process: {str(e)}")
        result_queue.put(f"I apologize, but I encountered an error while processing your request: {str(e)}")
        # Signal generation is complete
        result_queue.put(None)


# TextStreamer class for token-by-token generation
class TextStreamer:
    def __init__(self, tokenizer, queue):
        self.tokenizer = tokenizer
        self.queue = queue
        self.current_tokens = []
        self.prompt_consumed = False

    def put(self, token_ids):
        # generate() first calls put() with the full prompt ids; skip that call
        # so only newly generated tokens are streamed back to the client.
        if not self.prompt_consumed:
            self.prompt_consumed = True
            return
        self.current_tokens.extend(token_ids.reshape(-1).tolist())
        text = self.tokenizer.decode(self.current_tokens, skip_special_tokens=True)
        self.queue.put(text)

    def end(self):
        pass


# API route for home page
@app.route('/')
def home():
    return jsonify({"message": "AI Chat API is running!"})


# API route for streaming chat responses
@app.route('/chat', methods=['POST'])
def chat():
    data = request.get_json(silent=True) or {}
    message = data.get("message", "")

    if not message:
        return jsonify({"error": "Message is required"}), 400

    try:
        def generate():
            # Create a queue for communication between threads
            result_queue = queue.Queue()

            # Start thinking in a separate thread
            thread = threading.Thread(target=thinking_process, args=(message, result_queue))
            thread.start()

            # Stream results as they become available
            previous_text = ""
            while True:
                try:
                    result = result_queue.get(block=True, timeout=30)  # 30 second timeout

                    if result is None:
                        # End of generation
                        break

                    # Only yield the new part of the text
                    if isinstance(result, str):
                        new_part = result[len(previous_text):]
                        previous_text = result
                        if new_part:
                            yield f"data: {new_part}\n\n"
                except queue.Empty:
                    # Timeout occurred
                    yield "data: [Generation timeout. The model is taking too long to respond.]\n\n"
                    break

            yield "data: [DONE]\n\n"

        return Response(stream_with_context(generate()), mimetype='text/event-stream')
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500


# Simple API for non-streaming chat (fallback)
@app.route('/chat-simple', methods=['POST'])
def chat_simple():
    data = request.get_json(silent=True) or {}
    message = data.get("message", "")

    if not message:
        return jsonify({"error": "Message is required"}), 400

    try:
        # Create prompt with system message
        prompt = f"""<|im_start|>system
You are a helpful, friendly, and thoughtful AI assistant. Think carefully and provide informative, detailed responses.
<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
        # Handle inputs
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to('cpu') for k, v in inputs.items()}

        # Generate answer
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=1,
            no_repeat_ngram_size=3
        )

        # Decode only the newly generated tokens; the output also contains the
        # prompt, and the ChatML markers are removed by skip_special_tokens.
        answer = tokenizer.decode(
            output[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        ).strip()

        return jsonify({"response": answer})
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500


if __name__ == "__main__":
    try:
        # Initialize models at startup
        initialize_models()
        logger.info("Starting Flask application")
        app.run(host="0.0.0.0", port=7860)
    except Exception as e:
        logger.critical(f"Failed to start application: {str(e)}")
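
# Example client (a sketch, not part of the server): shows one way to call the
# two endpoints above. It assumes the server is reachable at
# http://localhost:7860 (the port passed to app.run) and that the `requests`
# package is installed; adjust the URL for your deployment.
#
#   import requests
#
#   # Non-streaming fallback endpoint
#   r = requests.post("http://localhost:7860/chat-simple", json={"message": "Hello!"})
#   print(r.json()["response"])
#
#   # Streaming endpoint (Server-Sent Events): read "data: ..." lines as they arrive
#   with requests.post("http://localhost:7860/chat", json={"message": "Hello!"}, stream=True) as r:
#       for line in r.iter_lines(decode_unicode=True):
#           if line and line.startswith("data: "):
#               chunk = line[len("data: "):]
#               if chunk == "[DONE]":
#                   break
#               print(chunk, end="", flush=True)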