import faiss
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from transformers import pipeline, AutoTokenizer, AutoModel

class TextClassifier:
    """Thin wrapper around a Transformers ``text-classification`` pipeline.

    NOTE(review): the default checkpoint name looks like a base model rather
    than one fine-tuned for classification -- confirm that the labels it
    produces are meaningful for the intended use.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        """Build the underlying pipeline for ``model_name``."""
        self.classifier = pipeline("text-classification", model=model_name)

    def classify(self, text):
        """Return the ``'label'`` field of the first prediction for ``text``."""
        predictions = self.classifier(text)
        top_prediction = predictions[0]
        return top_prediction['label']


class SentimentAnalyzer:
    """Thin wrapper around a Transformers ``sentiment-analysis`` pipeline."""

    def __init__(self, model_name='nlptown/bert-base-multilingual-uncased-sentiment'):
        """Build the underlying pipeline for ``model_name``."""
        self.analyzer = pipeline("sentiment-analysis", model=model_name)

    def analyze(self, text):
        """Return the first result the pipeline produces for ``text``.

        The result is passed through untouched; callers in this file read
        its ``'label'`` and ``'score'`` entries.
        """
        results = self.analyzer(text)
        first_result = results[0]
        return first_result




class ImageRecognizer:
    """Classify an image file with a pretrained torchvision model.

    The image is resized, center-cropped to 224x224, converted to a tensor
    and channel-normalized before being fed to the model in eval mode.
    """

    def __init__(self, model_name='resnet50'):
        """Load the pretrained model and build the preprocessing pipeline.

        Args:
            model_name: Name of a model constructor exposed by
                ``torchvision.models`` (default ``'resnet50'``).

        Raises:
            ValueError: If ``torchvision.models`` has no callable with that
                name.
        """
        # Previously resnet50 was always built and ``model_name`` was silently
        # ignored; look the constructor up by name so the argument is honoured.
        model_factory = getattr(models, model_name, None)
        if not callable(model_factory):
            raise ValueError(f"Unknown torchvision model: {model_name!r}")
        self.model = model_factory(pretrained=True)
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Three-channel statistics: inputs must therefore be RGB.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def recognize(self, image_path):
        """Return the index of the highest-scoring class for the image.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.

        Returns:
            int: Position of the maximum along dimension 1 of the model
            output for this single-image batch.
        """
        # The context manager releases the file handle; convert("RGB") makes
        # grayscale / RGBA / palette images compatible with the 3-channel
        # normalization and model input.
        with Image.open(image_path) as source_image:
            rgb_image = source_image.convert("RGB")
        batch = self.transform(rgb_image).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            outputs = self.model(batch)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()



class TextGenerator:
    """Thin wrapper around a Transformers ``text-generation`` pipeline."""

    def __init__(self, model_name='gpt2'):
        """Build the underlying pipeline for ``model_name``."""
        self.generator = pipeline("text-generation", model=model_name)

    def generate(self, prompt, max_length=100):
        """Generate one continuation of ``prompt`` and return its text.

        Args:
            prompt: Text handed to the generation pipeline.
            max_length: Forwarded to the pipeline as ``max_length``
                (default 100, the value that used to be hard-coded).
                NOTE(review): in Transformers this limit appears to cover
                prompt plus generated tokens, so long prompts need a larger
                value -- confirm against the library docs.

        Returns:
            The ``'generated_text'`` field of the first returned sequence.
        """
        response = self.generator(prompt, max_length=max_length, num_return_sequences=1)
        return response[0]['generated_text']




class FAQRetriever:
    """Nearest-neighbour FAQ lookup over transformer sentence embeddings.

    Texts are embedded by mean-pooling the model's last hidden state,
    L2-normalized, and stored in a flat FAISS L2 index.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        """Load tokenizer and model and create an empty index."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Size the index from the loaded model instead of hard-coding 384 so
        # other checkpoints work; 384 (the previous constant) is the fallback.
        embedding_dim = getattr(self.model.config, "hidden_size", 384)
        self.index = faiss.IndexFlatL2(embedding_dim)

    def embed(self, text):
        """Return the mean-pooled embedding(s) of ``text`` as a NumPy array.

        Args:
            text: Input passed straight to the tokenizer.

        Returns:
            Array with one row per tokenized sequence.
        """
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
            # Weight by the attention mask so padding positions do not dilute
            # the average when several texts of different length are batched.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid 0-division
            embeddings = (hidden * mask).sum(dim=1) / token_counts
        return embeddings.cpu().numpy()

    def add_faqs(self, faqs):
        """Embed ``faqs`` and append them to the index.

        ``self.faq_embeddings`` holds the normalized vectors of every FAQ
        added so far, in the same order as the index. An empty ``faqs`` is a
        no-op.
        """
        faqs = list(faqs)
        if not faqs:
            # np.concatenate raises on an empty list; nothing to add anyway.
            return
        new_embeddings = np.concatenate([self.embed(faq) for faq in faqs])
        faiss.normalize_L2(new_embeddings)  # in place
        self.index.add(new_embeddings)
        # Keep the attribute in step with the index across repeated calls
        # (it used to be overwritten while the index kept accumulating).
        previous = getattr(self, "faq_embeddings", None)
        if previous is None:
            self.faq_embeddings = new_embeddings
        else:
            self.faq_embeddings = np.concatenate([previous, new_embeddings])

    def retrieve(self, query, k=5):
        """Return the index positions of the FAQs closest to ``query``.

        Args:
            query: Text to look up.
            k: Maximum number of results (default 5).

        Returns:
            1-D integer array of at most ``k`` positions, nearest first;
            empty when nothing has been indexed.
        """
        # Never ask for more neighbours than exist: the index pads missing
        # results with -1, which callers would misread as "last element".
        k = min(k, self.index.ntotal)
        if k <= 0:
            return np.empty(0, dtype=np.int64)
        query_embedding = self.embed(query)
        faiss.normalize_L2(query_embedding)
        _, neighbour_ids = self.index.search(query_embedding, k)
        top_ids = neighbour_ids[0]
        return top_ids[top_ids >= 0]



class CustomerSupportAssistant:
    """Combine the model wrappers in this file into one support pipeline.

    A query is classified, sentiment-scored, optionally paired with an image
    prediction, matched against a fixed FAQ list, and the collected facts are
    formatted into a prompt for the text generator.
    """

    def __init__(self):
        """Instantiate every component and index the built-in FAQ list."""
        self.text_classifier = TextClassifier()
        self.sentiment_analyzer = SentimentAnalyzer()
        self.image_recognizer = ImageRecognizer()
        self.text_generator = TextGenerator()
        self.faq_retriever = FAQRetriever()
        # Positions in this list are what the retriever's results refer to.
        self.faqs = [
            "How to reset my password?",
            "What is the return policy?",
            "How to track my order?",
            "How to contact customer support?",
            "What payment methods are accepted?"
        ]
        self.faq_retriever.add_faqs(self.faqs)

    def process_query(self, text, image_path=None):
        """Return generated text for a customer query.

        Args:
            text: The customer's message.
            image_path: Optional image to run through the image recognizer;
                any falsy value skips recognition.

        Returns:
            Whatever the text generator produces for the assembled prompt.
        """
        topic = self.text_classifier.classify(text)
        sentiment = self.sentiment_analyzer.analyze(text)
        image_info = (
            self.image_recognizer.recognize(image_path)
            if image_path
            else "No image provided."
        )

        # Translate retrieved positions back into the FAQ strings.
        matched_indices = self.faq_retriever.retrieve(text)
        faq_responses = []
        for faq_index in matched_indices:
            faq_responses.append(self.faqs[faq_index])

        response_prompt = f"Topic: {topic}, Sentiment: {sentiment['label']} with confidence {sentiment['score']}. FAQs: {faq_responses}. Image info: {image_info}. Generate a response."
        return self.text_generator.generate(response_prompt)

# Example usage: build the assistant and answer one text-only query.
# NOTE(review): these statements sit at module top level with no
# `if __name__ == "__main__":` guard, so they also run whenever this file is
# imported -- constructing every component and printing the generated reply.
assistant = CustomerSupportAssistant()
input_text = "I'm having trouble with my recent order."
# No image_path is passed, so process_query takes its "No image provided." branch.
output = assistant.process_query(input_text)
print(output)