Mishal23 committed
Commit ec99f56 · verified
1 Parent(s): ef30302

Create app.py

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
+ # Import libraries
+ import pandas as pd
+ from fastapi import FastAPI
+ from fastapi.responses import HTMLResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import gradio as gr
+
+ # Load the customer-support Q&A dataset from the Hugging Face Hub
+ support_data = load_dataset("rjac/e-commerce-customer-support-qa")
+
+ # Load FAQ data from a local CSV file
+ faq_data = pd.read_csv("Ecommerce_FAQs.csv")
+
+ # Preprocess and clean the data
+ faq_data.rename(columns={'prompt': 'Question', 'response': 'Answer'}, inplace=True)
+ faq_data = faq_data[['Question', 'Answer']]
+ support_data_df = pd.DataFrame(support_data['train'])
+
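+ # NOTE: the rename above assumes the CSV ships its columns as 'prompt' and
+ # 'response'; if the export uses different headers, adjust the mapping.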
+ # Extract question-answer pairs from the conversation field
+ def extract_conversation(data):
+     try:
+         parts = data.split("\n\n")
+         question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
+         answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
+         return pd.Series({"Question": question, "Answer": answer})
+     except IndexError:
+         return pd.Series({"Question": "", "Answer": ""})
+
+ # Apply the extraction function to every conversation
+ support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)
+
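+ # NOTE: the split assumes each record looks roughly like
+ # "<header>\n\n<role>: <question>\n\n<role>: <answer>..."; records that
+ # don't match yield empty strings and fall back to generation at query time.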
+ # Combine FAQ data with support data; fill missing values so the string
+ # operations downstream never hit a NaN
+ combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)
+ combined_data = combined_data.fillna("")
+
+ # Initialize SBERT model
+ model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
+
+ # Generate embeddings for every question in the combined data
+ questions = combined_data['Question'].tolist()
+ embeddings = model.encode(questions, convert_to_tensor=True)
+
+ # Create FAISS index
+ index = faiss.IndexFlatL2(embeddings.shape[1])
+ index.add(embeddings.cpu().numpy())
+
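+ # NOTE: IndexFlatL2 performs exact nearest-neighbour search on raw L2
+ # distance. If cosine similarity is preferred, a common variant (a sketch,
+ # not what this app does) is to L2-normalize the vectors and switch to an
+ # inner-product index:
+ #   vecs = embeddings.cpu().numpy()
+ #   faiss.normalize_L2(vecs)
+ #   index = faiss.IndexFlatIP(vecs.shape[1])
+ #   index.add(vecs)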
+ # Load the fine-tuned DialoGPT model and tokenizer for the generative
+ # fallback (update the path if the fine-tuned model lives elsewhere)
+ tokenizer_gpt = AutoTokenizer.from_pretrained("Mishal23/fine_tuned_dialoGPT_model")
+ model_gpt = AutoModelForCausalLM.from_pretrained("Mishal23/fine_tuned_dialoGPT_model")
+
+ # Define Retrieval Function
+ def retrieve_answer(question):
+     question_embedding = model.encode([question], convert_to_tensor=True)
+     question_embedding_np = question_embedding.cpu().numpy()
+     _, closest_index = index.search(question_embedding_np, k=1)
+     best_match_idx = closest_index[0][0]
+     answer = combined_data.iloc[best_match_idx]['Answer']
+
+     # If the answer is empty, generate a fallback response from DialoGPT
+     if answer.strip() == "":
+         return generate_response(question)
+
+     return answer
+
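+ # NOTE: no distance threshold is applied, so even a distant match is
+ # returned verbatim. A possible refinement (a sketch; the cutoff value is an
+ # arbitrary assumption to tune on real queries):
+ #   distances, closest_index = index.search(question_embedding_np, k=1)
+ #   if distances[0][0] > SOME_THRESHOLD:
+ #       return generate_response(question)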
+ # Generate response using the fine-tuned DialoGPT model
+ def generate_response(user_input):
+     input_ids = tokenizer_gpt.encode(user_input, return_tensors='pt')
+     chat_history_ids = model_gpt.generate(input_ids, max_length=100, pad_token_id=tokenizer_gpt.eos_token_id)
+     response = tokenizer_gpt.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
+     return response if response.strip() else "Oops, I don't know the answer to that."
+
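+ # NOTE: DialoGPT checkpoints are trained with an EOS token separating turns;
+ # if replies look truncated or off, appending tokenizer_gpt.eos_token to
+ # user_input before encoding is worth trying (assumption: the fine-tune kept
+ # the original turn format).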
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allows all origins
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
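+ # NOTE: allow_origins=["*"] is convenient for a demo Space but should be
+ # narrowed to known frontends before exposing this API more widely.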
+ # Define a simple landing page for the FastAPI root route
+ @app.get("/")
+ async def read_root():
+     return HTMLResponse("""<html>
+     <head>
+         <title>E-commerce Support Chatbot</title>
+     </head>
+     <body>
+         <h1>Welcome to the E-commerce Support Chatbot</h1>
+         <p>Use the Gradio interface to chat with the bot!</p>
+     </body>
+     </html>""")
+
+ # Gradio chat handler for the e-commerce support chatbot
+ def chatbot_interface(user_input, chat_history=None):
+     # Avoid a shared mutable default argument; Gradio passes the state in
+     if chat_history is None:
+         chat_history = []
+
+     # Retrieve a response from the knowledge base or generate one
+     response = retrieve_answer(user_input)
+     chat_history.append(("User", user_input))
+     chat_history.append(("Bot", response))
+
+     # Format chat history for display
+     chat_display = []
+     for sender, message in chat_history:
+         if sender == "User":
+             chat_display.append(f"**You**: {message}")
+         else:
+             chat_display.append(f"**Bot**: {message}")
+     return "\n\n".join(chat_display), chat_history
+
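+ # NOTE: newer Gradio releases ship gr.Chatbot / gr.ChatInterface, which
+ # render conversation history natively; the Markdown-plus-State pairing here
+ # is a lightweight stand-in that also works on older versions.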
+ # Set up Gradio Chat Interface with conversational format
+ iface = gr.Interface(
+     fn=chatbot_interface,
+     inputs=[
+         gr.Textbox(lines=2, placeholder="Type your question here..."),
+         gr.State([]),  # State variable to maintain chat history
+     ],
+     outputs=[
+         gr.Markdown(),  # Display formatted chat history
+         gr.State(),     # Updated chat history
+     ],
+     title="E-commerce Support Chatbot",
+     description="Ask questions about order tracking, returns, account help, and more!",
+ )
+
132
+ # Launch Gradio interface directly
133
+ iface.launch()
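+ # NOTE: iface.launch() starts Gradio's own server; the FastAPI app defined
+ # above is not served by this call. To expose both from one process, one
+ # option (assumes a recent Gradio version) is mounting the interface on the
+ # FastAPI app and running it with uvicorn instead:
+ #   app = gr.mount_gradio_app(app, iface, path="/gradio")
+ #   # then: uvicorn app:app --host 0.0.0.0 --port 7860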