# Import libraries
import pandas as pd
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
import faiss
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

# Load the support Q&A dataset from the Hugging Face Hub
support_data = load_dataset("rjac/e-commerce-customer-support-qa")

# Load FAQ data from a local CSV file directly
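# (assumed to sit alongside this script, with 'prompt' and 'response'
#  columns matching the rename below)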
faq_data = pd.read_csv("Ecommerce_FAQs.csv")

# Preprocess and Clean Data
faq_data.rename(columns={'prompt': 'Question', 'response': 'Answer'}, inplace=True)
faq_data = faq_data[['Question', 'Answer']]
support_data_df = pd.DataFrame(support_data['train'])

# Extract question-answer pairs from the conversation field
def extract_conversation(data):
    try:
        parts = data.split("\n\n")
        question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
        answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
        return pd.Series({"Question": question, "Answer": answer})
    except IndexError:
        return pd.Series({"Question": "", "Answer": ""})
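
# Illustrative only: the parser above assumes each 'conversation' record looks
# roughly like "<meta>\n\nCustomer: <question>\n\nAgent: <answer>\n\n..."
# (an assumption about the dataset's format, not verified here); records that
# don't match simply yield empty Question/Answer strings.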

# Apply extraction function
support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)

# Combine FAQ data with support data
combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)

# Initialize SBERT Model
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
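# (all-mpnet-base-v2 yields 768-dimensional embeddings; the index below reads
#  the dimensionality from embeddings.shape[1] rather than hard-coding it)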

# Generate and Index Embeddings for Combined Data
questions = combined_data['Question'].tolist()
embeddings = model.encode(questions, convert_to_numpy=True)

# Create FAISS index (exact L2 search over the question embeddings)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
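
# Note: IndexFlatL2 ranks matches by Euclidean distance. A common alternative
# for SBERT embeddings is cosine similarity (a sketch, left commented out):
#   faiss.normalize_L2(embeddings)                 # normalize in place
#   index = faiss.IndexFlatIP(embeddings.shape[1])
#   index.add(embeddings)                          # inner product == cosine on unit vectors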

# Load your fine-tuned DialoGPT model and tokenizer
tokenizer_gpt = AutoTokenizer.from_pretrained("Mishal23/fine_tuned_dialoGPT_model")  # Update with your fine-tuned model path
model_gpt = AutoModelForCausalLM.from_pretrained("Mishal23/fine_tuned_dialoGPT_model")  # Update with your fine-tuned model path

# Define Retrieval Function
def retrieve_answer(question):
    question_embedding = model.encode([question], convert_to_numpy=True)
    _, closest_index = index.search(question_embedding, k=1)
    best_match_idx = closest_index[0][0]
    answer = combined_data.iloc[best_match_idx]['Answer']

    # Guard against missing/NaN answers (a non-string would crash .strip()),
    # then fall back to the generative model
    if not isinstance(answer, str) or answer.strip() == "":
        return generate_response(question)

    return answer
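
# Note: a k=1 lookup always returns the nearest stored question, even for
# off-topic queries. Thresholding the L2 distance returned by index.search
# (its first return value) and routing low-confidence matches to
# generate_response() would be a natural refinement.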

# Generate response using your fine-tuned DialoGPT model
def generate_response(user_input):
    # DialoGPT expects an end-of-sequence token appended to the user turn
    input_ids = tokenizer_gpt.encode(user_input + tokenizer_gpt.eos_token, return_tensors='pt')
    chat_history_ids = model_gpt.generate(input_ids, max_length=100, pad_token_id=tokenizer_gpt.eos_token_id)
    # Decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer_gpt.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    return response if response.strip() else "Oops, I don't know the answer to that."
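
# (max_length above caps prompt + response tokens combined; max_new_tokens
#  would cap only the generated portion, if stricter control is needed)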

# Initialize FastAPI app
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define FastAPI route for Gradio interface
@app.get("/")
async def read_root():
    return HTMLResponse("""<html>
        <head>
            <title>E-commerce Support Chatbot</title>
        </head>
        <body>
            <h1>Welcome to the E-commerce Support Chatbot</h1>
            <p>Use the Gradio interface to chat with the bot!</p>
        </body>
    </html>""")

# Gradio Chat Interface for E-commerce Support Chatbot
def chatbot_interface(user_input, chat_history=None):
    # Avoid a mutable default argument; Gradio supplies the state on each call
    chat_history = chat_history or []

    # Retrieve response from the knowledge base or generate it
    response = retrieve_answer(user_input)
    chat_history.append(("User", user_input))
    chat_history.append(("Bot", response))

    # Format chat history for display
    chat_display = []
    for sender, message in chat_history:
        if sender == "User":
            chat_display.append(f"**You**: {message}")
        else:
            chat_display.append(f"**Bot**: {message}")
    return "\n\n".join(chat_display), chat_history

# Set up Gradio Chat Interface with conversational format
iface = gr.Interface(
    fn=chatbot_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Type your question here..."),
        gr.State([])  # State variable to maintain chat history
    ],
    outputs=[
        gr.Markdown(),  # Display formatted chat history
        gr.State()  # Update state
    ],
    title="E-commerce Support Chatbot",
    description="Ask questions about order tracking, returns, account help, and more!",
)
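
# The gr.State input/output pair threads the per-session chat history through
# successive calls to chatbot_interface.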

# Launch Gradio interface directly
iface.launch()
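
# Note: iface.launch() starts Gradio's own server, so the FastAPI app and its
# "/" route above are never actually served. A sketch of serving both from one
# process (assuming gradio's mount_gradio_app helper and uvicorn are available):
#
#   import uvicorn
#   app = gr.mount_gradio_app(app, iface, path="/gradio")
#   uvicorn.run(app, host="0.0.0.0", port=7860)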