import os
import re
import string
from datetime import datetime

import gradio as gr
import torch
import yfinance as yf
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Get Hugging Face token from environment variables
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    # Login to Hugging Face
    login(token=hf_token)
    print("Successfully logged in to Hugging Face")
else:
    print("WARNING: HF_TOKEN not found in environment variables. You may face access issues for gated models.")

# Load the Chatbot Model and Tokenizer
model_name = "Akshit-77/llama-3.2-3b-chatbot"

# Load the tokenizer, passing the auth token (None when HF_TOKEN is unset)
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    raise

# Configure 4-bit quantization for memory efficiency
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load model with quantization and auth token
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        token=hf_token,
    )
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# CSS for Chatbot UI
css = """
#chatbot {
    font-family: Arial, sans-serif;
    background-color: #e5ddd5;
}
.message {
    padding: 10px 15px;
    border-radius: 7.5px;
    margin: 5px 0;
    max-width: 75%;
    position: relative;
}
.user-message {
    background: #dcf8c6;
    margin-left: auto;
    margin-right: 10px;
}
.bot-message {
    background: white;
    margin-left: 10px;
}
.timestamp {
    font-size: 0.7em;
    color: #667781;
    float: right;
    margin-left: 10px;
    margin-top: 3px;
}
"""


class StockDataRetriever:
    """Look up quote data for NSE-listed stocks via Yahoo Finance (yfinance)."""

    def __init__(self):
        # Maps user-facing aliases (upper-case, words joined by hyphens) to
        # Yahoo Finance tickers; NSE tickers carry the ".NS" suffix.
        self.stock_mapping = {
            # Existing mappings
            'RELIANCE': 'RELIANCE.NS',
            'TCS': 'TCS.NS',
            'HDFCBANK': 'HDFCBANK.NS',
            'INFY': 'INFY.NS',
            'ICICIBANK': 'ICICIBANK.NS',
            # Expanded Mappings
            'RELIANCE-INDUSTRIES': 'RELIANCE.NS',
            # HDFC Ltd merged into HDFC Bank in July 2023 and HDFC.NS was delisted.
            'HDFC': 'HDFCBANK.NS',
            'ONGC': 'ONGC.NS',
            'INDIAN-OIL-CORPORATION': 'IOC.NS',
            'ADANI-GROUP': 'ADANIENT.NS',  # Using Adani Enterprises as a representative
            'HERO-MOTOCORP': 'HEROMOTOCO.NS',
            'ASIAN-PAINTS': 'ASIANPAINT.NS',
            'EICHER-MOTORS': 'EICHERMOT.NS',
            'ITC': 'ITC.NS',
            'TATA-STEEL': 'TATASTEEL.NS',
            'SHRIRAM-TRANSPORT-FINANCE': 'SHRIRAMFIN.NS',
            'DR-REDDYS-LABORATORIES': 'DRREDDY.NS',
            'INFOSYS': 'INFY.NS',
            'SUN-PHARMA': 'SUNPHARMA.NS',
            'TATA-CONSULTANCY-SERVICES': 'TCS.NS',
            'MARUTI-SUZUKI': 'MARUTI.NS',
            'HCL-TECHNOLOGIES': 'HCLTECH.NS',
            'COAL-INDIA': 'COALINDIA.NS',
            # Mindtree merged into L&T Infotech; the combined company trades as LTIM.
            'LTI-MINDTREE': 'LTIM.NS',
            'HDFC-LIFE': 'HDFCLIFE.NS',
            'BAJAJ-AUTO': 'BAJAJ-AUTO.NS',
            'BRITANNIA-INDUSTRIES': 'BRITANNIA.NS',
            'HINDALCO-INDUSTRIES': 'HINDALCO.NS',
            'LARSEN-AND-TOUBRO': 'LT.NS',
            'TATA-CONSUMER-PRODUCTS': 'TATACONSUM.NS',
            'WIPRO': 'WIPRO.NS',
            'TITAN': 'TITAN.NS',
            'BAJAJ-FINANCE': 'BAJFINANCE.NS',
            'JSW-STEEL': 'JSWSTEEL.NS',
            'ICICI-BANK': 'ICICIBANK.NS',
            'INDUSIND-BANK': 'INDUSINDBK.NS',
            'BHARTI-AIRTEL': 'BHARTIARTL.NS',
            'DIVIS-LABORATORIES': 'DIVISLAB.NS',
            'SBI-LIFE-INSURANCE': 'SBILIFE.NS',
            'BAJAJ-FINSERV': 'BAJAJFINSV.NS',
            'CIPLA': 'CIPLA.NS',
            'GRASIM-INDUSTRIES': 'GRASIM.NS',
            'HINDUSTAN-UNILEVER': 'HINDUNILVR.NS',
            'MAHINDRA-AND-MAHINDRA': 'M&M.NS',
            'TATA-MOTORS': 'TATAMOTORS.NS',
            'APOLLO-HOSPITALS-ENTERPRISES': 'APOLLOHOSP.NS',
            'SBI': 'SBIN.NS',
            'KOTAK-MAHINDRA-BANK': 'KOTAKBANK.NS',
            'POWER-GRID-CORPORATION-OF-INDIA': 'POWERGRID.NS',
            'AXIS-BANK': 'AXISBANK.NS',
            'NTPC': 'NTPC.NS',
            'TECH-MAHINDRA': 'TECHM.NS',
            'ADANI-PORTS': 'ADANIPORTS.NS',
            'ULTRATECH-CEMENT': 'ULTRACEMCO.NS',
            # Nestle India's NSE symbol is NESTLEIND.
            'NESTLE': 'NESTLEIND.NS',
            'BHARAT-PETROLEUM': 'BPCL.NS',
        }

    def get_stock_data(self, symbol: str):
        """Fetch the latest quote for ``symbol`` from Yahoo Finance.

        ``symbol`` is upper-cased and looked up in ``stock_mapping``; unknown
        symbols are assumed to be NSE tickers and get a ".NS" suffix.

        Returns:
            On success, a dict with ``current_price``, ``previous_close``,
            ``day_high``, ``day_low`` (``"N/A"`` when missing) and
            ``last_updated`` (the local time of this fetch, not an exchange
            timestamp). On failure, a dict with a single ``"error"`` message.
        """
        try:
            # Convert symbol to Yahoo Finance format
            yf_symbol = self.stock_mapping.get(symbol.upper(), f"{symbol.upper()}.NS")
            stock = yf.Ticker(yf_symbol)
            info = stock.info

            # Treat an empty info dict, or one without a current price, as "not found"
            if not info or 'currentPrice' not in info:
                return {"error": f"Stock symbol '{symbol}' not found or invalid. Please verify the symbol."}

            return {
                "current_price": info.get("currentPrice", "N/A"),
                "previous_close": info.get("previousClose", "N/A"),
                "day_high": info.get("dayHigh", "N/A"),
                "day_low": info.get("dayLow", "N/A"),
                "last_updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
        except Exception as e:
            return {"error": f"Could not fetch stock data: {str(e)}"}


class RAGPipeline:
    """Answer questions with the fine-tuned LLM.

    When a query looks like a stock-price question, live quote data is fetched
    and injected into the prompt as context.
    """

    def __init__(self, model_path):
        # NOTE: model_path is currently unused; the module-level tokenizer and
        # model loaded above are reused so the weights are only loaded once.
        self.tokenizer = tokenizer
        self.model = model
        self.stock_retriever = StockDataRetriever()
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')

        # Reference phrases: a query semantically close to any of these is
        # treated as a stock-price question (see _is_price_query).
        self.knowledge_base = [
            "stock price of",
            "current price",
            "stock performance",
            "today's stock price",
            "stock data for",
            "price of stock",
        ]
        self.knowledge_embeddings = self.encoder.encode(self.knowledge_base)

        # Predefined stock symbols for easier matching
        self.stock_symbols = list(self.stock_retriever.stock_mapping.keys())
        # Try longer aliases first so 'HDFC-LIFE' wins over 'HDFC', and require
        # non-alphanumeric neighbours so 'ITC' does not match inside 'SWITCH'.
        self._symbol_patterns = [
            (symbol, re.compile(rf"(?<![A-Z0-9]){re.escape(symbol)}(?![A-Z0-9])"))
            for symbol in sorted(self.stock_symbols, key=len, reverse=True)
        ]

    def _extract_stock_symbol(self, query):
        """Return the stock alias or ticker mentioned in ``query``, or None.

        Known aliases are matched as whole words (spaces in the query are
        treated as hyphens, so "Tata Steel" matches 'TATA-STEEL'). Otherwise
        the last word of the query, minus surrounding punctuation, is returned
        as a guessed ticker if it is longer than one character.
        """
        normalized = re.sub(r"\s+", "-", query.upper().strip())
        for symbol, pattern in self._symbol_patterns:
            if pattern.search(normalized):
                return symbol

        # Fallback: try the last word if it looks like a symbol
        words = query.split()
        if words:
            candidate = words[-1].strip(string.punctuation).upper()
            if len(candidate) > 1:
                return candidate
        return None

    def _is_price_query(self, query):
        """Return True if ``query`` is semantically close to a price question."""
        query_embedding = self.encoder.encode([query.lower()])
        similarities = cosine_similarity(query_embedding, self.knowledge_embeddings)[0]
        # A cosine similarity above 0.5 to any reference phrase counts as a match
        return max(similarities) > 0.5

    def _format_stock_data(self, stock_data):
        """Format stock data into a readable string (or return its error message)."""
        if 'error' in stock_data:
            return stock_data['error']
        return (
            f"Stock Data:\n"
            f"Current Price: ₹{stock_data['current_price']}\n"
            f"Previous Close: ₹{stock_data['previous_close']}\n"
            f"Day's High: ₹{stock_data['day_high']}\n"
            f"Day's Low: ₹{stock_data['day_low']}\n"
            f"Last Updated: {stock_data['last_updated']}"
        )

    def generate_response(self, query, max_new_tokens=400):
        """Generate an answer to ``query``, adding live stock context when relevant.

        Args:
            query: The user's question.
            max_new_tokens: Upper bound on generated tokens, excluding the prompt.

        Returns:
            The model's answer only, without the echoed prompt.
        """
        stock_context = ""
        if self._is_price_query(query):
            symbol = self._extract_stock_symbol(query)
            if symbol:
                stock_data = self.stock_retriever.get_stock_data(symbol)
                stock_context = self._format_stock_data(stock_data)
            else:
                stock_context = "No specific stock symbol could be identified."

        # Prepare input for the model with stock context
        full_prompt = (
            f"Context: {stock_context}\n\n"
            f"Question: {query}\n"
            "Answer:"
        )

        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
        # Some causal-LM tokenizers define no pad token; fall back to EOS so
        # generate() does not have to guess.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            # max_new_tokens (unlike max_length) does not count the prompt, so
            # a long stock context cannot eat into the answer budget.
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_token_id,
            )

        # generate() returns prompt + continuation; decode only the continuation
        prompt_length = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()


# Initialize the pipeline once
try:
    pipeline = RAGPipeline(model_name)
    print("RAG Pipeline initialized successfully")
except Exception as e:
    print(f"Error initializing RAG Pipeline: {e}")
    raise


def chat(message, history):
    """Gradio submit handler.

    Appends a ``(message, reply)`` pair to the chat history and returns the
    updated history plus an empty string to clear the input textbox. Blank
    messages are ignored.
    """
    history = history or []
    if not message or not message.strip():
        return history, ""
    try:
        response = pipeline.generate_response(message)
        history.append((message, response))
    except Exception as e:
        history.append((message, f"Error generating response: {str(e)}"))
    return history, ""


def create_interface():
    """Build the Gradio Blocks UI for the stock-market assistant."""
    with gr.Blocks(css=css) as iface:
        gr.HTML("<h1 style='text-align: center;'>Indian Stock Market Assistant</h1>")
        gr.HTML("<p style='text-align: center;'>Ask me about Indian stock prices or any general questions.</p>")
        chatbot = gr.Chatbot(height=600, elem_id="chatbot")
        txt = gr.Textbox(
            placeholder="Type your question here (e.g., 'What is the current price of RELIANCE?')",
            show_label=False,
        )
        txt.submit(chat, [txt, chatbot], [chatbot, txt])
        gr.HTML("""
<div style='text-align: center;'>
<p>This chatbot provides real-time Indian stock market data and can answer general questions.</p>
<p>Examples: "What's the current price of TCS?", "How is HDFC performing today?", "Tell me about RELIANCE stock"</p>
</div>
""")
    return iface


# Create and launch the interface
try:
    iface = create_interface()
    print("Interface created successfully")
except Exception as e:
    print(f"Error creating interface: {e}")
    raise

# For Hugging Face Spaces deployment
if __name__ == "__main__":
    iface.launch()