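# Gradio chatbot app: crawls the resources listed in Config.INDEXED_URLS
# (HTML pages, PDFs, and iCalendar feeds), indexes their text with
# sentence-transformer embeddings, and answers questions with DialoGPT,
# citing the source URLs it drew from.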
import gradio as gr
import logging
import requests
import time
from bs4 import BeautifulSoup
from datetime import datetime
from typing import List, Optional, Tuple
from urllib.parse import urljoin, urlparse
import random
import nltk
from nltk.tokenize import sent_tokenize
import PyPDF2
import io
from joblib import dump, load
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from icalendar import Calendar
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry  # import directly; requests.packages.urllib3 is a deprecated alias
from fake_useragent import UserAgent
from concurrent.futures import ThreadPoolExecutor

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Download NLTK data
try:
    nltk.download('punkt', quiet=True)
except Exception as e:
    logger.warning(f"Failed to download NLTK data: {e}")

class Config:
    MODEL_NAME = "microsoft/DialoGPT-medium"
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    MAX_TOKENS = 1000
    REQUEST_TIMEOUT = 10
    MAX_DEPTH = 1
    SIMILARITY_THRESHOLD = 0.5
    CHUNK_SIZE = 512
    MAX_WORKERS = 5
    INDEXED_URLS = {
        "https://drive.google.com/file/d/1d5kkqaQkdiA2SwJ0JFrTuKO9zauiUtFz/view?usp=sharing"
    }
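
# Holds one crawled resource: its raw text, the overlapping text chunks it is
# split into, and the embeddings used for similarity search.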
class ResourceItem:
    def __init__(self, url: str, content: str, resource_type: str):
        self.url = url
        self.content = content
        self.type = resource_type
        self.embedding = None
        self.chunks = []
        self.chunk_embeddings = []

    def __str__(self):
        return f"ResourceItem(type={self.type}, url={self.url}, content_length={len(self.content)})"

    def create_chunks(self, chunk_size=Config.CHUNK_SIZE):
        """Split content into overlapping chunks for better context preservation."""
        words = self.content.split()
        overlap = chunk_size // 4  # 25% overlap
        for i in range(0, len(words), chunk_size - overlap):
            chunk = ' '.join(words[i:i + chunk_size])
            if chunk:
                self.chunks.append(chunk)
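
# Requests session wrapper with retry/backoff, rotating User-Agent headers,
# and a small random delay between requests to stay polite.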
class RobustCrawler:
    def __init__(self, max_retries=3, backoff_factor=0.3):
        self.ua = UserAgent()
        self.session = self._create_robust_session(max_retries, backoff_factor)

    def _create_robust_session(self, max_retries, backoff_factor):
        session = requests.Session()
        retry_strategy = Retry(
            total=max_retries,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS"],  # 'method_whitelist' was removed in urllib3 2.0
            backoff_factor=backoff_factor,
            raise_on_status=False
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("https://", adapter)
        session.mount("http://", adapter)
        return session

    def get_headers(self):
        return {
            "User-Agent": self.ua.random,
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.5",
            "Referer": "https://www.google.com/",
            "DNT": "1",
            "Connection": "keep-alive",
            "Upgrade-Insecure-Requests": "1"
        }

    def crawl_with_exponential_backoff(self, url, timeout=Config.REQUEST_TIMEOUT):
        try:
            time.sleep(random.uniform(0.5, 2.0))
            response = self.session.get(
                url,
                headers=self.get_headers(),
                timeout=timeout
            )
            response.raise_for_status()
            return response
        except requests.exceptions.RequestException as e:
            logger.error(f"Crawling error for {url}: {e}")
            return None
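
# Main chatbot: loads the language and embedding models, crawls and indexes the
# configured resources, and answers questions by retrieving relevant chunks.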
class SchoolChatbot:
    def __init__(self):
        logger.info("Initializing SchoolChatbot...")
        self.setup_models()
        self.resources = []
        self.visited_urls = set()
        self.crawl_and_index_resources()

    def setup_models(self):
        try:
            logger.info("Setting up models...")
            self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
            self.model = AutoModelForCausalLM.from_pretrained(Config.MODEL_NAME)
            self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
            logger.info("Models setup completed successfully.")
        except Exception as e:
            logger.error(f"Failed to setup models: {e}")
            raise RuntimeError("Failed to initialize required models")

    def crawl_and_index_resources(self):
        logger.info("Starting to crawl and index resources...")
        with ThreadPoolExecutor(max_workers=Config.MAX_WORKERS) as executor:
            futures = [executor.submit(self.crawl_url, url, 0) for url in Config.INDEXED_URLS]
            for future in futures:
                try:
                    future.result()
                except Exception as e:
                    logger.error(f"Error in crawling thread: {e}")
        logger.info(f"Crawling completed. Indexed {len(self.resources)} resources.")
    def crawl_url(self, url, depth):
        if depth > Config.MAX_DEPTH or url in self.visited_urls:
            return
        self.visited_urls.add(url)
        crawler = RobustCrawler()
        response = crawler.crawl_with_exponential_backoff(url)
        if not response:
            logger.error(f"Failed to retrieve content from {url}. Please check the URL and permissions.")
            return
        content_type = response.headers.get("Content-Type", "").lower()
        try:
            if "text/calendar" in content_type or url.endswith(".ics"):
                self.extract_ics_content(url, response.text)
            elif "text/html" in content_type:
                self.extract_html_content(url, response, depth)
            elif "application/pdf" in content_type:
                self.extract_pdf_content(url, response.content)
            else:
                logger.warning(f"Unknown content type for {url}: {content_type}")
                self.store_resource(url, response.text, 'unknown')
        except Exception as e:
            logger.error(f"Error processing {url}: {e}")
    def extract_ics_content(self, url, ics_text):
        try:
            cal = Calendar.from_ical(ics_text)
            events = []
            for component in cal.walk():
                if component.name == "VEVENT":
                    event = self._format_calendar_event(component)
                    if event:
                        events.append(event)
            if events:
                self.store_resource(url, "\n".join(events), 'calendar')
        except Exception as e:
            logger.error(f"Error parsing ICS from {url}: {e}")

    def _format_calendar_event(self, event):
        try:
            summary = event.get("SUMMARY", "No Summary")
            start = event.get("DTSTART")
            end = event.get("DTEND")
            start = start.dt if start else None  # avoid AttributeError when the property is missing
            end = end.dt if end else None
            description = event.get("DESCRIPTION", "")
            location = event.get("LOCATION", "")
            event_details = [f"Event: {summary}"]
            if start:
                event_details.append(f"Start: {start}")
            if end:
                event_details.append(f"End: {end}")
            if location:
                event_details.append(f"Location: {location}")
            if description:
                event_details.append(f"Description: {description}")
            return " | ".join(event_details)
        except Exception:
            return None
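
    # Strip boilerplate tags, pull the main content plus heading sections, and
    # follow in-page links while the crawl depth limit allows.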
    def extract_html_content(self, url, response, depth):
        try:
            soup = BeautifulSoup(response.content, 'html.parser')
            # Remove unwanted elements
            for element in soup.find_all(['script', 'style', 'nav', 'footer']):
                element.decompose()
            content_sections = []
            # Extract main content
            main_content = soup.find(['main', 'article', 'div'], class_=['content', 'main-content'])
            if main_content:
                content_sections.append(main_content.get_text(strip=True, separator=' '))
            # Extract headings and their associated content
            for heading in soup.find_all(['h1', 'h2', 'h3']):
                section = [heading.get_text(strip=True)]
                next_elem = heading.find_next_sibling()
                while next_elem and next_elem.name in ['p', 'ul', 'ol', 'div']:
                    section.append(next_elem.get_text(strip=True))
                    next_elem = next_elem.find_next_sibling()
                content_sections.append(' '.join(section))
            if content_sections:
                self.store_resource(url, ' '.join(content_sections), 'webpage')
            # Follow links only while within the crawl depth limit
            if depth < Config.MAX_DEPTH:
                self._process_links(soup, url, depth)
        except Exception as e:
            logger.error(f"Error extracting HTML content from {url}: {e}")

    def _process_links(self, soup, base_url, depth):
        try:
            for link in soup.find_all('a', href=True):
                full_url = urljoin(base_url, link['href'])
                if self.is_valid_url(full_url) and full_url not in self.visited_urls:
                    time.sleep(random.uniform(0.5, 2.0))
                    self.crawl_url(full_url, depth + 1)  # was len(self.visited_urls), which broke the depth limit
        except Exception as e:
            logger.error(f"Error processing links from {base_url}: {e}")
    def extract_pdf_content(self, url, pdf_content):
        try:
            pdf_file = io.BytesIO(pdf_content)
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            text_content = []
            for page in pdf_reader.pages:
                try:
                    text_content.append(page.extract_text())
                except Exception as e:
                    logger.error(f"Error extracting text from PDF page: {e}")
                    continue
            if text_content:
                self.store_resource(url, ' '.join(text_content), 'pdf')
        except Exception as e:
            logger.error(f"Error extracting PDF content from {url}: {e}")
    def store_resource(self, url, text_data, resource_type):
        try:
            # Create resource item and split into chunks
            item = ResourceItem(url, text_data, resource_type)
            item.create_chunks()
            # Generate embeddings for chunks
            item.chunk_embeddings = [
                self.embedding_model.encode(chunk)
                for chunk in item.chunks
            ]
            # Calculate average embedding
            if item.chunk_embeddings:
                item.embedding = np.mean(item.chunk_embeddings, axis=0)
            self.resources.append(item)
            logger.debug(f"Stored resource: {url} (type={resource_type})")
        except Exception as e:
            logger.error(f"Error storing resource {url}: {e}")

    def is_valid_url(self, url):
        try:
            parsed = urlparse(url)
            return bool(parsed.scheme) and bool(parsed.netloc)
        except Exception:
            return False
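
    # Rank all indexed chunks by cosine similarity to the query and return the
    # top matches above Config.SIMILARITY_THRESHOLD as (chunk, score, url) tuples.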
    def find_best_matching_chunks(self, query, n_chunks=3):
        if not self.resources:
            return []
        try:
            query_embedding = self.embedding_model.encode(query)
            all_chunks = []
            for resource in self.resources:
                for chunk, embedding in zip(resource.chunks, resource.chunk_embeddings):
                    score = cosine_similarity([query_embedding], [embedding])[0][0]
                    if score > Config.SIMILARITY_THRESHOLD:
                        all_chunks.append((chunk, score, resource.url))
            # Sort by similarity score and get top n chunks
            all_chunks.sort(key=lambda x: x[1], reverse=True)
            return all_chunks[:n_chunks]
        except Exception as e:
            logger.error(f"Error finding matching chunks: {e}")
            return []
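
    # Build a prompt from the best-matching chunks, generate a reply with the
    # causal LM, and append the source URLs.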
    def generate_response(self, user_input):
        try:
            # Find best matching chunks
            best_chunks = self.find_best_matching_chunks(user_input)
            if not best_chunks:
                return "I apologize, but I couldn't find any relevant information in my knowledge base. Could you please rephrase your question or ask about something else?"
            # Prepare context from best matching chunks
            context = "\n".join([chunk[0] for chunk in best_chunks])
            # Build the prompt from the retrieved context and the user question
            conversation = f"Context: {context}\nUser: {user_input}\nAssistant:"
            # Generate response; truncate the prompt so prompt plus reply fit
            # within DialoGPT's 1024-token context window
            input_ids = self.tokenizer.encode(
                conversation,
                return_tensors='pt',
                truncation=True,
                max_length=Config.CHUNK_SIZE
            )
            response_ids = self.model.generate(
                input_ids,
                max_length=Config.MAX_TOKENS,
                pad_token_id=self.tokenizer.eos_token_id,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
            response = self.tokenizer.decode(
                response_ids[:, input_ids.shape[-1]:][0],
                skip_special_tokens=True
            )
            # Format response with sources
            source_urls = list(set(chunk[2] for chunk in best_chunks))
            sources = "\n\nSources:\n" + "\n".join(source_urls)
            return response + sources
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "I apologize, but I encountered an error while processing your question. Please try again."
def create_gradio_interface(chatbot):
    def respond(user_input):
        return chatbot.generate_response(user_input)

    interface = gr.Interface(
        fn=respond,
        inputs=gr.Textbox(
            label="Ask a Question",
            placeholder="Type your question here...",
            lines=2
        ),
        outputs=gr.Textbox(
            label="Answer",
            placeholder="Response will appear here...",
            lines=5
        ),
        title="School Information Chatbot",
        description="Ask about school events, policies, or other information. The chatbot will provide answers based on available school documents and resources.",
        examples=[
            ["What events are happening this week?"],
            ["When is the next board meeting?"],
            ["What is the school's attendance policy?"]
        ],
        theme=gr.themes.Soft(),
        flagging_mode="never"
    )
    return interface

if __name__ == "__main__":
    try:
        chatbot = SchoolChatbot()
        interface = create_gradio_interface(chatbot)
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            debug=True
        )
    except Exception as e:
        logger.error(f"Failed to start application: {e}")