from flask import Flask, request, render_template, send_from_directory
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
from gtts import gTTS
import os
import soundfile as sf
from transformers import VitsTokenizer, VitsModel, set_seed
from IndicTransToolkit import IndicProcessor

# Create the Flask application and the on-disk folders it writes into.
app = Flask(__name__)
UPLOAD_FOLDER = "./static/uploads/"
AUDIO_FOLDER = "./static/audio/"
# Both folders must exist before the first request saves a file into them.
for _folder in (UPLOAD_FOLDER, AUDIO_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER, AUDIO_FOLDER=AUDIO_FOLDER)

# Load models
# Resolve the inference device once so every model below lands on the same one.
_device = "cuda" if torch.cuda.is_available() else "cpu"

# BLIP: image -> English caption.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(_device)

# IndicTrans2: English -> Indic translation.
# NOTE: trust_remote_code=True runs model code downloaded from the Hub.
model_name = "ai4bharat/indictrans2-en-indic-1B"
tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
if _device == "cpu":
    # Dynamic int8 quantization only has CPU kernels, so it is applied only
    # when inference runs on the CPU; a quantized model cannot run on CUDA.
    model_IT2 = torch.quantization.quantize_dynamic(
        model_IT2, {torch.nn.Linear}, dtype=torch.qint8
    )
else:
    # On GPU keep the original weights and move them to the device.
    model_IT2.to(_device)
ip = IndicProcessor(inference=True)

# Functions
def generate_caption(image_path):
    """Generate an English caption for the image stored at *image_path*.

    The image is converted to RGB and passed to the BLIP processor together
    with the text prompt "image of"; the generated token ids are decoded
    with special tokens stripped.

    Args:
        image_path: Filesystem path of the image to caption.

    Returns:
        The decoded caption string.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Image.open keeps the underlying file open; convert() loads the pixel
    # data into a new image, so the file can be closed as soon as it returns.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    inputs = blip_processor(image, "image of", return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs)
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)

def translate_caption(caption, target_languages):
    """Translate an English caption into each requested target language.

    Args:
        caption: English sentence to translate.
        target_languages: Iterable of target language tags passed to the
            IndicProcessor (e.g. "hin_Deva").

    Returns:
        A dict mapping each target language tag to its translated string.
        Empty when *target_languages* is empty.
    """
    source_language = "eng_Latn"
    sentences = [caption]
    results = {}

    for language in target_languages:
        prepared = ip.preprocess_batch(
            sentences, src_lang=source_language, tgt_lang=language
        )
        encoded = tokenizer_IT2(
            prepared, truncation=True, padding="longest", return_tensors="pt"
        )
        encoded = encoded.to("cuda" if torch.cuda.is_available() else "cpu")

        # Beam search, one hypothesis per input sentence.
        with torch.no_grad():
            output_ids = model_IT2.generate(
                **encoded,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode with the tokenizer switched to its target-side mode.
        with tokenizer_IT2.as_target_tokenizer():
            decoded = tokenizer_IT2.batch_decode(
                output_ids.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Only one sentence was submitted, so take the first result.
        results[language] = ip.postprocess_batch(decoded, lang=language)[0]

    return results

def generate_audio_gtts(text, lang_code, output_file):
    """Synthesize *text* with gTTS and save the audio to *output_file*.

    Args:
        text: Text to be spoken.
        lang_code: Language code handed to gTTS as ``lang``.
        output_file: Destination path for the saved audio.

    Returns:
        *output_file*, unchanged, so callers can store the path directly.
    """
    gTTS(text=text, lang=lang_code).save(output_file)
    return output_file

@app.route("/", methods=["GET", "POST"])
def index():
    """Render the upload page and process submitted images.

    A GET request, or a POST without a usable ``image`` file, renders
    ``index.html`` with no context. A POST carrying an image saves the
    upload, captions it, translates the caption into every language listed
    in the ``languages`` form field, synthesizes one MP3 per translation,
    and renders ``index.html`` with ``image_path``, ``caption``,
    ``translations`` and ``audio_files``.
    """
    if request.method != "POST":
        return render_template("index.html")

    image_file = request.files.get("image")
    if not image_file:
        return render_template("index.html")

    # The filename is chosen by the client. Keep only its final path
    # component so a name such as "../../app.py" cannot escape the upload
    # folder; backslashes are normalized first to cover Windows-style paths.
    filename = os.path.basename(image_file.filename.replace("\\", "/"))
    if filename in ("", ".", ".."):
        return render_template("index.html")

    image_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
    image_file.save(image_path)

    caption = generate_caption(image_path)

    # Language codes are also client-supplied and end up in audio file
    # names; accept only codes made of letters, digits and underscores so
    # they can never contain path separators or dots.
    target_languages = [
        lang
        for lang in request.form.getlist("languages")
        if lang.replace("_", "").isalnum()
    ]
    translations = translate_caption(caption, target_languages)

    # Translation language tag -> gTTS language code. Tags missing from
    # this table fall back to "en".
    lang_codes = {
        "hin_Deva": "hi", "guj_Gujr": "gu", "urd_Arab": "ur", "mar_Deva": "mr"
    }
    audio_files = {}
    for lang, translation in translations.items():
        audio_file_path = os.path.join(app.config["AUDIO_FOLDER"], f"{lang}.mp3")
        audio_files[lang] = generate_audio_gtts(
            translation, lang_codes.get(lang, "en"), audio_file_path
        )

    return render_template(
        "index.html",
        image_path=image_path,
        caption=caption,
        translations=translations,
        audio_files=audio_files,
    )

@app.route("/audio/<filename>")
def audio(filename):
    """Serve the file named *filename* from the configured audio folder."""
    audio_dir = app.config["AUDIO_FOLDER"]
    return send_from_directory(audio_dir, filename)

if __name__ == "__main__":
    # Start Flask's built-in server on port 7860, bound to host 0.0.0.0.
    app.run(port=7860, host="0.0.0.0")