|
import gradio as gr |
|
import torch |
|
import os |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import yfinance as yf |
|
from datetime import datetime |
|
from huggingface_hub import login |
|
|
|
|
|
# Optional Hugging Face Hub authentication. The token (possibly None) is also
# handed to the from_pretrained calls so gated checkpoints can be downloaded.
hf_token = os.environ.get("HF_TOKEN")

if not hf_token:
    print("WARNING: HF_TOKEN not found in environment variables. You may face access issues for gated models.")
else:
    login(token=hf_token)
    print("Successfully logged in to Hugging Face")
|
|
|
|
|
# Hub repository of the chat model (a Llama 3.2 3B fine-tune, judging by the
# repo name). Used for both the tokenizer and the model weights.
model_name = "Akshit-77/llama-3.2-3b-chatbot"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
except Exception as exc:
    # Report the failure, then abort start-up: nothing works without a tokenizer.
    print(f"Error loading tokenizer: {exc}")
    raise
|
|
|
|
|
# 4-bit quantization through bitsandbytes needs a CUDA device. On CPU-only
# hosts, skip quantization and load the regular weights instead of crashing
# at start-up. ``bnb_config`` stays defined (None) in that case.
if torch.cuda.is_available():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="fp4",
        bnb_4bit_compute_dtype="float16",
    )
else:
    bnb_config = None
    print("WARNING: CUDA not available; loading model without 4-bit quantization (this will be slow).")

try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,  # None -> plain, unquantized load
        device_map="auto",
        token=hf_token,
    )
    print("Model loaded successfully")
except Exception as e:
    # Report the failure, then abort start-up: the app is useless without a model.
    print(f"Error loading model: {e}")
    raise
|
|
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...) in create_interface.
# "#chatbot" targets the gr.Chatbot created with elem_id="chatbot".
# NOTE(review): the .message / .user-message / .bot-message / .timestamp
# classes are not attached to any component visible in this file; confirm
# whether they are still needed.
css = """
#chatbot {
    font-family: Arial, sans-serif;
    background-color: #e5ddd5;
}
.message {
    padding: 10px 15px;
    border-radius: 7.5px;
    margin: 5px 0;
    max-width: 75%;
    position: relative;
}
.user-message {
    background: #dcf8c6;
    margin-left: auto;
    margin-right: 10px;
}
.bot-message {
    background: white;
    margin-left: 10px;
}
.timestamp {
    font-size: 0.7em;
    color: #667781;
    float: right;
    margin-left: 10px;
    margin-top: 3px;
}
"""
|
|
|
class StockDataRetriever:
    """Fetch live quotes for Indian equities from Yahoo Finance (yfinance).

    ``stock_mapping`` maps upper-case lookup keys (bare NSE codes or company
    names with words joined by hyphens) to Yahoo Finance tickers.
    """

    def __init__(self):
        self.stock_mapping = {
            # Bare NSE codes
            'RELIANCE': 'RELIANCE.NS',
            'TCS': 'TCS.NS',
            'HDFCBANK': 'HDFCBANK.NS',
            'INFY': 'INFY.NS',
            'ICICIBANK': 'ICICIBANK.NS',

            'RELIANCE-INDUSTRIES': 'RELIANCE.NS',
            # HDFC Ltd merged into HDFC Bank (July 2023) and HDFC.NS was
            # delisted, so "HDFC" resolves to the surviving listing.
            'HDFC': 'HDFCBANK.NS',
            'ONGC': 'ONGC.NS',
            'INDIAN-OIL-CORPORATION': 'IOC.NS',
            'ADANI-GROUP': 'ADANIENT.NS',

            'HERO-MOTOCORP': 'HEROMOTOCO.NS',
            'ASIAN-PAINTS': 'ASIANPAINT.NS',
            'EICHER-MOTORS': 'EICHERMOT.NS',
            'ITC': 'ITC.NS',
            'TATA-STEEL': 'TATASTEEL.NS',

            'SHRIRAM-TRANSPORT-FINANCE': 'SHRIRAMFIN.NS',
            'DR-REDDYS-LABORATORIES': 'DRREDDY.NS',
            'INFOSYS': 'INFY.NS',
            'SUN-PHARMA': 'SUNPHARMA.NS',
            'TATA-CONSULTANCY-SERVICES': 'TCS.NS',

            'MARUTI-SUZUKI': 'MARUTI.NS',
            'HCL-TECHNOLOGIES': 'HCLTECH.NS',
            'COAL-INDIA': 'COALINDIA.NS',
            # Mindtree merged into L&T Infotech; the combined company trades
            # as LTIM (MINDTREE.NS no longer quotes).
            'LTI-MINDTREE': 'LTIM.NS',
            'HDFC-LIFE': 'HDFCLIFE.NS',

            'BAJAJ-AUTO': 'BAJAJ-AUTO.NS',
            'BRITANNIA-INDUSTRIES': 'BRITANNIA.NS',
            'HINDALCO-INDUSTRIES': 'HINDALCO.NS',
            'LARSEN-AND-TOUBRO': 'LT.NS',
            'TATA-CONSUMER-PRODUCTS': 'TATACONSUM.NS',

            'WIPRO': 'WIPRO.NS',
            'TITAN': 'TITAN.NS',
            'BAJAJ-FINANCE': 'BAJFINANCE.NS',
            'JSW-STEEL': 'JSWSTEEL.NS',
            'ICICI-BANK': 'ICICIBANK.NS',

            'INDUSIND-BANK': 'INDUSINDBK.NS',
            'BHARTI-AIRTEL': 'BHARTIARTL.NS',
            'DIVIS-LABORATORIES': 'DIVISLAB.NS',
            'SBI-LIFE-INSURANCE': 'SBILIFE.NS',
            'BAJAJ-FINSERV': 'BAJAJFINSV.NS',

            'CIPLA': 'CIPLA.NS',
            'GRASIM-INDUSTRIES': 'GRASIM.NS',
            'HINDUSTAN-UNILEVER': 'HINDUNILVR.NS',
            'MAHINDRA-AND-MAHINDRA': 'M&M.NS',
            'TATA-MOTORS': 'TATAMOTORS.NS',

            'APOLLO-HOSPITALS-ENTERPRISES': 'APOLLOHOSP.NS',
            'SBI': 'SBIN.NS',
            'KOTAK-MAHINDRA-BANK': 'KOTAKBANK.NS',
            'POWER-GRID-CORPORATION-OF-INDIA': 'POWERGRID.NS',
            'AXIS-BANK': 'AXISBANK.NS',

            'NTPC': 'NTPC.NS',
            'TECH-MAHINDRA': 'TECHM.NS',
            'ADANI-PORTS': 'ADANIPORTS.NS',
            'ULTRATECH-CEMENT': 'ULTRACEMCO.NS',
            # Nestlé India's NSE code is NESTLEIND.
            'NESTLE': 'NESTLEIND.NS',
            'BHARAT-PETROLEUM': 'BPCL.NS'
        }

    def _resolve_symbol(self, symbol):
        """Translate a user-supplied *symbol* into a Yahoo Finance ticker.

        The lookup key is upper-cased, stripped, and has inner whitespace
        collapsed to hyphens, so "tata steel" finds 'TATA-STEEL'. Tickers that
        already carry an NSE/BSE suffix ("TCS.NS", "TCS.BO") are returned as-is
        rather than getting a second suffix. Any other unknown key defaults to
        the NSE listing "<KEY>.NS".
        """
        key = "-".join(symbol.upper().split())
        if key in self.stock_mapping:
            return self.stock_mapping[key]
        if key.endswith((".NS", ".BO")):
            return key
        return f"{key}.NS"

    def get_stock_data(self, symbol: str):
        """Fetch a price snapshot for *symbol* from Yahoo Finance.

        Returns a dict with current_price, previous_close, day_high, day_low
        (values from ``Ticker.info``, "N/A" when absent) and a local-time
        last_updated stamp. Never raises: every failure, including a blank
        symbol or a ticker without a current price, comes back as
        ``{"error": <message>}``.
        """
        if not symbol or not symbol.strip():
            return {"error": "No stock symbol provided."}
        try:
            yf_symbol = self._resolve_symbol(symbol)
            info = yf.Ticker(yf_symbol).info

            # yfinance returns a sparse/empty info dict for unknown tickers.
            if not info or 'currentPrice' not in info:
                return {"error": f"Stock symbol '{symbol}' not found or invalid. Please verify the symbol."}

            return {
                "current_price": info.get("currentPrice", "N/A"),
                "previous_close": info.get("previousClose", "N/A"),
                "day_high": info.get("dayHigh", "N/A"),
                "day_low": info.get("dayLow", "N/A"),
                "last_updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
        except Exception as e:
            # Network/parsing failures are reported to the chat, not raised.
            return {"error": f"Could not fetch stock data: {e}"}
|
|
|
|
|
class RAGPipeline:
    """Answer questions with the causal LM, injecting live stock data as context.

    If a query is semantically close to one of a few price-question templates,
    a stock symbol is extracted from it, fetched through ``StockDataRetriever``
    and placed in the prompt's "Context:" section before generation.
    """

    # Minimum cosine similarity to any template phrase for a query to count
    # as a stock-price question.
    PRICE_QUERY_THRESHOLD = 0.5

    # Words that often trail a company name ("... stock price today") but are
    # never tickers; the last-word fallback in _extract_stock_symbol skips them.
    _FILLER_WORDS = frozenset(
        {"TODAY", "NOW", "STOCK", "STOCKS", "SHARE", "SHARES", "PRICE", "PRICES", "PLEASE"}
    )

    def __init__(self, model_path):
        # ``model_path`` is not used: the pipeline reuses the module-level
        # ``tokenizer`` and ``model`` loaded at import time.
        self.tokenizer = tokenizer
        self.model = model
        self.stock_retriever = StockDataRetriever()
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')

        # Template phrases describing a price question; queries are compared
        # against their embeddings in _is_price_query.
        self.knowledge_base = [
            "stock price of",
            "current price",
            "stock performance",
            "today's stock price",
            "stock data for",
            "price of stock"
        ]
        self.knowledge_embeddings = self.encoder.encode(self.knowledge_base)

        self.stock_symbols = list(self.stock_retriever.stock_mapping.keys())

    @staticmethod
    def _tokenize(query):
        """Upper-case *query* and split it into alphanumeric tokens.

        Apostrophes are dropped (so "Reddy's" -> "REDDYS"); every other
        non-alphanumeric character acts as a separator.
        """
        chars = []
        for ch in query.upper():
            if ch in "'\u2019":
                continue
            chars.append(ch if ch.isalnum() else " ")
        return "".join(chars).split()

    def _extract_stock_symbol(self, query):
        """Return the stock key mentioned in *query*, or a best-guess ticker.

        Known mapping keys are matched as whole hyphen-joined token runs,
        longest first. This way "HDFC Life" yields 'HDFC-LIFE' rather than
        'HDFC', and "ITC" is not found inside "SWITCH". If no key matches, the
        last non-filler word (punctuation stripped) is returned when it is
        longer than one character. Otherwise the result is None.
        """
        tokens = self._tokenize(query)
        if not tokens:
            return None

        # Join with hyphens to match the mapping's key format; the sentinel
        # hyphens at both ends make every token boundary explicit.
        haystack = "-" + "-".join(tokens) + "-"
        for symbol in sorted(self.stock_symbols, key=len, reverse=True):
            if f"-{symbol}-" in haystack:
                return symbol

        for word in reversed(tokens):
            if word in self._FILLER_WORDS:
                continue
            return word if len(word) > 1 else None
        return None

    def _is_price_query(self, query):
        """Return True if *query* resembles one of the price-question templates."""
        query_embedding = self.encoder.encode([query.lower()])
        similarities = cosine_similarity(query_embedding, self.knowledge_embeddings)[0]
        return max(similarities) > self.PRICE_QUERY_THRESHOLD

    def _format_stock_data(self, stock_data):
        """Format a get_stock_data() result as readable text (or its error message)."""
        if 'error' in stock_data:
            return stock_data['error']

        return (
            f"Stock Data:\n"
            f"Current Price: ₹{stock_data['current_price']}\n"
            f"Previous Close: ₹{stock_data['previous_close']}\n"
            f"Day's High: ₹{stock_data['day_high']}\n"
            f"Day's Low: ₹{stock_data['day_low']}\n"
            f"Last Updated: {stock_data['last_updated']}"
        )

    def generate_response(self, query, max_new_tokens=400):
        """Generate the model's answer to *query*.

        Args:
            query: The user's question.
            max_new_tokens: Upper bound on generated tokens, excluding the
                prompt.

        Returns:
            Only the newly generated text, with special tokens removed and
            surrounding whitespace stripped. The prompt itself is not echoed.
        """
        stock_context = ""
        if self._is_price_query(query):
            symbol = self._extract_stock_symbol(query)
            if symbol:
                stock_data = self.stock_retriever.get_stock_data(symbol)
                stock_context = self._format_stock_data(stock_data)
            else:
                stock_context = "No specific stock symbol could be identified."

        full_prompt = (
            f"Context: {stock_context}\n\n"
            f"Question: {query}\n"
            "Answer:"
        )

        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        # The tokenizer may define no pad token (common for Llama); use EOS
        # explicitly so generate() does not have to guess on every call.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            # **inputs forwards the attention mask along with input_ids.
            # max_new_tokens bounds only the continuation; max_length would
            # also count the prompt and could leave no room for an answer.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_token_id,
            )

        # generate() returns prompt + continuation; decode just the continuation.
        answer_ids = outputs[0][prompt_length:]
        return self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
# Build the single shared pipeline at import time; start-up aborts on failure.
try:
    pipeline = RAGPipeline(model_name)
except Exception as exc:
    print(f"Error initializing RAG Pipeline: {exc}")
    raise
else:
    print("RAG Pipeline initialized successfully")
|
|
|
|
|
def chat(message, history):
    """Gradio submit handler: answer *message* and append the turn to *history*.

    Args:
        message: Text the user submitted.
        history: Existing chat history as a list of (user, bot) tuples, or None.

    Returns:
        A pair (updated_history, "") — the empty string clears the input box.
        Blank or whitespace-only messages are ignored and leave history as-is.
    """
    history = history or []
    if not message or not message.strip():
        return history, ""
    try:
        response = pipeline.generate_response(message)
    except Exception as e:
        # Show the failure in the chat window rather than breaking the UI callback.
        response = f"Error generating response: {e}"
    history.append((message, response))
    return history, ""
|
|
|
|
|
def create_interface():
    """Assemble and return the Gradio Blocks UI.

    Layout: title, subtitle, chat window, input textbox, and an examples
    footer. Submitting the textbox calls ``chat`` with the message and the
    current history; its two return values refresh the chat window and clear
    the textbox.
    """
    footer_html = """
    <div style="text-align: center; margin-top: 20px; padding: 10px; background-color: #f0f0f0; border-radius: 5px;">
        <p>This chatbot provides real-time Indian stock market data and can answer general questions.</p>
        <p>Examples: "What's the current price of TCS?", "How is HDFC performing today?", "Tell me about RELIANCE stock"</p>
    </div>
    """

    with gr.Blocks(css=css) as demo:
        gr.HTML("<h1>Indian Stock Market Assistant</h1>")
        gr.HTML("<p>Ask me about Indian stock prices or any general questions.</p>")

        chat_window = gr.Chatbot(height=600, elem_id="chatbot")
        prompt_box = gr.Textbox(
            placeholder="Type your question here (e.g., 'What is the current price of RELIANCE?')",
            show_label=False,
        )

        prompt_box.submit(fn=chat, inputs=[prompt_box, chat_window], outputs=[chat_window, prompt_box])

        gr.HTML(footer_html)

    return demo
|
|
|
|
|
# Build the UI at import time so ``iface`` exists for the launcher below.
try:
    iface = create_interface()
except Exception as exc:
    print(f"Error creating interface: {exc}")
    raise
else:
    print("Interface created successfully")
|
|
|
|
|
# Script entry point: start the Gradio server for the interface built at import.
if __name__ == "__main__":
    iface.launch()