# ChatBot / app.py — Hugging Face Space by sanjeevbora
# Last commit: "authentication" (ba9e337, verified); file size 6.06 kB.
import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
from langchain_community.document_loaders import DirectoryLoader
import torch
import re
import requests
from urllib.parse import urlencode, parse_qs, urlparse
import transformers
import spaces
# Initialize embeddings and ChromaDB
# Sentence-embedding model used both to index the PDFs and to embed queries.
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name, model_kwargs=model_kwargs)

# Load every PDF under ./example (recursively) as LangChain documents.
loader = DirectoryLoader('./example', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
if not docs:
    # Indexing an empty document list gives the chatbot nothing to retrieve and may
    # fail inside Chroma with an unclear error; stop early with an explicit message.
    raise RuntimeError("No PDF documents found under ./example; cannot build the vector store.")

# Build and persist the vector store, then retrieve from this same instance rather
# than opening a second client on the same directory. Retrieval then does not depend
# on the data having been flushed to disk first.
# NOTE(review): from_documents adds the documents on every start. If companies_db
# survives restarts, entries may be duplicated — confirm the storage behaviour.
vectordb = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory="companies_db")
books_db = vectordb
books_db_client = books_db.as_retriever()
# Initialize the model and tokenizer
# Chat model that writes the answer from the retrieved context.
llm_model_name = "stabilityai/stablelm-zephyr-3b"

# pipeline(torch_dtype=...) is only forwarded to from_pretrained when the pipeline
# loads the model itself. Because an already-built model is passed in below, the
# precision has to be chosen here. Half precision is used on CUDA only; on CPU the
# model stays in float32.
llm_dtype = torch.float16 if device == "cuda" else torch.float32

# Generation length is set by max_new_tokens on the pipeline below. It is not
# injected into the model config, which previously held a conflicting value.
model_config = transformers.AutoConfig.from_pretrained(llm_model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    llm_model_name,
    # Runs Python code shipped in the model repository.
    # NOTE(review): consider pinning a revision.
    trust_remote_code=True,
    config=model_config,
    device_map=device,
    torch_dtype=llm_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)

query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # Prompt + completion are returned together. test_rag extracts the text after
    # "Helpful Answer:", which is presumably the marker that ends RetrievalQA's
    # default prompt.
    return_full_text=True,
    torch_dtype=llm_dtype,
    device_map=device,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256,
)
llm = HuggingFacePipeline(pipeline=query_pipeline)

# The "stuff" chain type puts all retrieved documents into a single prompt.
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True,
)
# OAuth Configuration
import os

# App registration on the Microsoft identity platform, used by "Login with Microsoft".
# The tenant and client IDs can be overridden through environment variables; the
# defaults are the IDs this app was registered with.
TENANT_ID = os.environ.get("AZURE_TENANT_ID", "2b093ced-2571-463f-bc3e-b4f8bcb427ee")
CLIENT_ID = os.environ.get("AZURE_CLIENT_ID", "2a7c884c-942d-49e2-9e5d-7a29d8a0d3e5")

# The client secret must never be committed to source control. Supply it through the
# AZURE_CLIENT_SECRET environment variable (for example a Hugging Face Space secret).
# NOTE(review): the secret that used to be hard-coded here is exposed in the repo
# history and should be revoked/rotated in the Azure portal.
CLIENT_SECRET = os.environ.get("AZURE_CLIENT_SECRET", "")

REDIRECT_URI = 'https://sanjeevbora-chatbot.hf.space/'

# The endpoints are derived from TENANT_ID so the two cannot drift apart.
AUTH_URL = f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/authorize"
TOKEN_URL = f"https://login.microsoftonline.com/{TENANT_ID}/oauth2/v2.0/token"
GRAPH_API_URL = "https://graph.microsoft.com/v1.0/me"
# Function to redirect to Microsoft login
def get_login_url(state=None):
    """Build the Microsoft authorization-code login URL for this app.

    Args:
        state: Opaque value that Microsoft echoes back on the redirect. If omitted,
            a fresh random value is generated. For CSRF protection, the caller must
            remember the value it passes (or read it back from the returned URL) and
            compare it with the ``state`` query parameter on the callback.

    Returns:
        ``AUTH_URL`` followed by the URL-encoded request parameters.
    """
    import secrets

    if state is None:
        # A fixed state value gives no CSRF protection, so use an unguessable one.
        state = secrets.token_urlsafe(16)
    params = {
        'client_id': CLIENT_ID,
        'response_type': 'code',
        'redirect_uri': REDIRECT_URI,
        'response_mode': 'query',
        'scope': 'User.Read',
        'state': state,
    }
    return f"{AUTH_URL}?{urlencode(params)}"
# Function to exchange auth code for an access token
def exchange_code_for_token(auth_code, timeout=10):
    """Redeem an OAuth authorization code for an access token.

    Posts the code and this app's client credentials to ``TOKEN_URL``.

    Args:
        auth_code: The ``code`` query parameter from the OAuth redirect.
        timeout: Seconds to wait for the token endpoint before ``requests``
            raises a timeout error. Defaults to 10.

    Returns:
        The ``access_token`` from the JSON response, or ``None`` if the response
        has no such field (for example when the code was rejected).
    """
    data = {
        'grant_type': 'authorization_code',
        'client_id': CLIENT_ID,
        'client_secret': CLIENT_SECRET,
        'code': auth_code,
        'redirect_uri': REDIRECT_URI,
    }
    # Without a timeout, requests can block indefinitely on an unresponsive server.
    response = requests.post(TOKEN_URL, data=data, timeout=timeout)
    return response.json().get('access_token')
# Step 3: Function to get user profile
def get_user_profile(access_token, timeout=10):
    """Fetch the signed-in user's profile from Microsoft Graph (``GRAPH_API_URL``).

    Args:
        access_token: Bearer token obtained from ``exchange_code_for_token``.
        timeout: Seconds to wait for Graph before ``requests`` raises a timeout
            error. Defaults to 10.

    Returns:
        The parsed JSON body of the response. On failure this is Graph's error
        payload rather than a profile.
    """
    headers = {'Authorization': f'Bearer {access_token}'}
    # Without a timeout, requests can block indefinitely on an unresponsive server.
    response = requests.get(GRAPH_API_URL, headers=headers, timeout=timeout)
    return response.json()
# Function to handle OAuth callback
def handle_oauth_callback(url, expected_state=None):
    """Finish the OAuth flow from a redirect URL and return the user's profile.

    Args:
        url: The full URL Microsoft redirected to. It must carry a ``code``
            query parameter. Empty or ``None`` is treated as a failure.
        expected_state: If given, the URL's ``state`` query parameter must equal
            this value, otherwise the callback is rejected (CSRF check). ``None``
            skips the check.

    Returns:
        The Graph profile JSON on success. Returns the string
        ``"Authorization failed."`` when there is no code, the state does not
        match, or no access token could be obtained.
    """
    if not url:
        return "Authorization failed."
    query_params = parse_qs(urlparse(url).query)
    auth_code = query_params.get('code', [None])[0]
    if not auth_code:
        return "Authorization failed."
    if expected_state is not None and query_params.get('state', [None])[0] != expected_state:
        return "Authorization failed."
    access_token = exchange_code_for_token(auth_code)
    if not access_token:
        # Without this check Graph would be called with "Bearer None", and its
        # error payload would be shown to the user as if it were a profile.
        return "Authorization failed."
    return get_user_profile(access_token)
# Function to retrieve answer using the RAG system
# The LLM pipeline returns prompt + completion, so the answer is taken from the text
# after the "Helpful Answer:" marker. Compiled once at import time.
_HELPFUL_ANSWER_RE = re.compile(r"Helpful Answer:(.*)", re.DOTALL)


@spaces.GPU(duration=60)
def test_rag(query):
    """Answer *query* with the RetrievalQA chain and return only the answer text.

    Decorated with ``spaces.GPU(duration=60)`` so the call is given a GPU slot on
    Hugging Face ZeroGPU.

    Args:
        query: The user's question.

    Returns:
        The stripped text after the first "Helpful Answer:" in the chain output,
        or ``"No helpful answer found."`` if that marker is absent.
    """
    # Chain.run is deprecated. invoke takes the chain's input dict ("query") and
    # returns its outputs dict, with the answer under "result".
    chain_output = books_db_client_retriever.invoke({"query": query})["result"]
    match = _HELPFUL_ANSWER_RE.search(chain_output)
    if match is None:
        return "No helpful answer found."
    return match.group(1).strip()
# Function for RAG Chat
def chat(query, history=None):
    """Handle one chat turn: answer *query* and add it to the chat history.

    Args:
        query: Text from the question box. ``None``, empty, or whitespace-only
            input is ignored, so no RAG call (and no GPU time) is spent on it.
        history: Current chat history (a list of question/answer pairs) from
            the Chatbot component. ``None`` means no history.

    Returns:
        ``(new_history, "")``. The empty string clears the input box. The
        caller's history list is copied, not modified in place.
    """
    new_history = list(history) if history else []
    question = (query or "").strip()
    if question:
        new_history.append((question, test_rag(question)))
    return new_history, ""  # Clear input after submission
# Gradio interface
with gr.Blocks() as interface:
    gr.Markdown("## RAG Chatbot")
    gr.Markdown("Ask a question and get answers based on retrieved documents.")
    input_box = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
    submit_btn = gr.Button("Submit")
    chat_history = gr.Chatbot(label="Chat History")

    # Microsoft OAuth login. Gradio has no gr.redirect() helper, so the button
    # navigates to the Microsoft login page itself through its `link` argument.
    auth_btn = gr.Button("Login with Microsoft", link=get_login_url())

    # OAuth callback URL input. This is for demonstration; a real deployment would
    # handle the redirect on a server route instead of asking the user to paste it.
    callback_url = gr.Textbox(label="OAuth Callback URL", placeholder="Paste the callback URL here...")

    # Displays the user profile after login.
    profile_output = gr.JSON(label="User Profile")

    def handle_callback_action(url):
        """Run the OAuth callback for *url* and return a value gr.JSON can render."""
        if not url:
            # The textbox was cleared; clear the profile panel instead of
            # attempting a login with no code.
            return None
        user_profile = handle_oauth_callback(url)
        if isinstance(user_profile, str):
            # gr.JSON expects a str to be valid JSON text, and a plain message
            # such as "Authorization failed." is not, so wrap it in a dict.
            return {"error": user_profile}
        return user_profile

    # Run the OAuth callback handler whenever the callback URL changes.
    callback_url.change(handle_callback_action, inputs=[callback_url], outputs=[profile_output])
    # Submit action for chat.
    submit_btn.click(chat, inputs=[input_box, chat_history], outputs=[chat_history, input_box])
interface.launch()