from flask import Flask, request, jsonify, render_template, Response
import os
from dotenv import load_dotenv
import json
import requests
from PIL import Image
import base64
from io import BytesIO
import logging
import time

# 加载环境变量
load_dotenv()

app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_headers():
    """创建API请求头"""
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
    }

def image_to_base64(image):
    """将PIL图像转换为base64"""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

def process_image_with_vision(image_base64):
    """使用Vision API处理图像"""
    url = f"{os.getenv('OPENAI_API_BASE')}/v1/chat/completions"
    
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": """请仔细分析这张图片中的题目,要求:

1. 识别内容类型:
   - 判断是否为数学题、物理题、化学题等
   - 识别是选择题、填空题还是解答题
   - 判断是否包含图表或特殊符号

2. 内容提取:
   - 提取所有文字内容,包括题干、选项(如果有)
   - 识别所有数学公式、化学方程式或特殊符号
   - 保留原有的排版格式(如换行、缩进等)

3. 公式处理:
   - 将所有数学公式转换为LaTeX格式
   - 使用[formula_n]作为占位符,其中n为公式编号
   - 保持公式的完整性和准确性

请使用以下JSON格式返回:
{
    "type": "题目类型(如:数学/物理/化学)",
    "format": "题目格式(如:选择题/填空题/解答题)",
    "text": "包含[formula_1], [formula_2]等占位符的完整文本",
    "formulas": ["latex公式1", "latex公式2"],
    "options": ["A. xxx", "B. xxx"] // 如果是选择题则包含此字段
    "notes": ["可能存在的问题说明1", "问题说明2"]
}"""
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{image_base64}"
                    }
                }
            ]
        }
    ]
    
    payload = {
        "model": os.getenv('OPENAI_VISION_MODEL'),
        "messages": messages,
        "max_tokens": 1500,
        "temperature": 0.2
    }
    
    try:
        response = requests.post(url, headers=create_headers(), json=payload)
        response.raise_for_status()
        result = response.json()
        
        content = result['choices'][0]['message']['content']
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            return {
                "type": "unknown",
                "format": "unknown",
                "text": content,
                "formulas": [],
                "notes": ["无法解析为JSON格式"]
            }
            
    except Exception as e:
        logger.error(f"Vision API调用错误: {str(e)}")
        return {"error": str(e)}

def handle_sse_response(raw_data):
    """处理SSE响应数据"""
    if raw_data:
        try:
            data = json.loads(raw_data)
            if len(data['choices']) > 0:
                delta = data['choices'][0].get('delta', {})
                content = delta.get('content', '')
                print(f"Stream content: {content}", flush=True)  # 直接打印流式内容
                return content
        except json.JSONDecodeError:
            pass
    return ''

def stream_solve(problem_text, formulas):
    url = f"{os.getenv('OPENAI_API_BASE')}/v1/chat/completions"
    
    full_problem = problem_text
    for i, formula in enumerate(formulas, 1):
        full_problem = full_problem.replace(f"[formula_{i}]", f"$${formula}$$")
    
    print(f"\n问题内容: {full_problem}")
    
    messages = [
    {
        "role": "system",
        "content": """请按照以下格式回答问题:
        [开始解答]
        这是答案:(写出明确的答案,如"答案为:2.5米","答案选 A"等)
        [解答结束]

        [开始解析]
        这是解析:(写出详细的解析过程,需严格遵循以下要求:

        解题思路要求:
        1. 开篇明确指出题目涉及的核心知识点和解题方向
        2. 采用合理的解题策略,并说明策略选择的依据
        3. 对于多解法题目,说明最优解法的选择理由

        专业性要求:
        1. 严格使用规范的学科专业术语,避免口语化表达
        2. 数学公式、物理量、化学方程式等必须符合学科规范
        3. 计算步骤要详实,重要的中间步骤不得省略
        4. 注意数据的有效位数,物理量单位的规范性

        逻辑性要求:
        1. 解题步骤要层次分明,逻辑推导严谨
        2. 关键结论要有充分的推导过程和理论依据
        3. 复杂问题应合理拆分为子步骤,循序渐进

        教学性要求:
        1. 对重要概念和关键步骤要适当添加解释说明
        2. 特别标注解题中的难点和易错点
        3. 适时补充相关的知识点拓展
        4. 针对典型错误进行分析和提醒

        书写规范:
        1. 专业符号和公式要规范,确保清晰美观
        2. 合理使用缩进和分段,突出层次结构
        3. 保持语言的严谨性和专业性)
        [解析结束]"""
    },
    {
        "role": "user",
        "content": full_problem
    }
    ]
    
    payload = {
        "model": os.getenv('OPENAI_CHAT_MODEL'),
        "messages": messages,
        "stream": True,
        "temperature": 0.3
    }
    
    try:
        response = requests.post(
            url,
            headers=create_headers(),
            json=payload,
            stream=True
        )
        response.raise_for_status()

        complete_response = ""
        
        for line in response.iter_lines():
            if not line or not line.startswith(b'data: '):
                continue
                
            try:
                data = json.loads(line[6:].decode('utf-8'))
                content = data['choices'][0].get('delta', {}).get('content', '')
                complete_response += content
                
            except json.JSONDecodeError:
                continue
                    
    except Exception as e:
        print(f"\n错误: {str(e)}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
        return

    # 提取答案和解析
    try:
        if '[开始解答]' in complete_response and '[开始解析]' in complete_response:
            parts = complete_response.split('[开始解答]')
            answer_part = parts[1].split('[解答结束]')[0].strip()
            
            parts = complete_response.split('[开始解析]')
            analysis_part = parts[1].split('[解析结束]')[0].strip()
            
            # 控制台打印
            print("\n完整回答:")
            print(f"答案: {answer_part}")
            print(f"解析: {analysis_part}")
            
            # 返回结果给前端
            yield f"data: {json.dumps({'type': 'answer', 'content': answer_part})}\n\n"
            yield f"data: {json.dumps({'type': 'analysis', 'content': analysis_part})}\n\n"
            
        else:
            print("\n响应格式不符合预期")
            yield f"data: {json.dumps({'error': '响应格式不符合预期'})}\n\n"
            
    except Exception as e:
        print(f"\n解析错误: {str(e)}")
        yield f"data: {json.dumps({'error': f'解析响应时出错: {str(e)}'})}\n\n"

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/process', methods=['POST'])
def process():
    if 'file' not in request.files:
        return jsonify({'error': '没有文件上传'}), 400
    
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': '没有选择文件'}), 400
    
    try:
        image = Image.open(file)
        image_base64 = image_to_base64(image)
        
        result = process_image_with_vision(image_base64)
        
        if 'error' in result:
            return jsonify({'error': result['error']}), 500
        
        return jsonify({
            'original_image': image_base64,
            'result': result
        })
        
    except Exception as e:
        logger.error(f"处理图像错误: {str(e)}")
        return jsonify({'error': str(e)}), 500

@app.route('/solve', methods=['POST'])
def solve():
    data = request.json
    if not data or 'text' not in data or 'formulas' not in data:
        return jsonify({'error': '无效的请求数据'}), 400
    
    return Response(
        stream_solve(data['text'], data['formulas']),
        content_type='text/event-stream'
    )

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=True)