# finbot / app.py
# Hugging Face Space by Akshit-77 — commit 187a6d3 ("Update app.py").
import gradio as gr
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import yfinance as yf
from datetime import datetime
from huggingface_hub import login
# Authenticate against the Hugging Face Hub when HF_TOKEN is set; the token is
# also passed explicitly to the from_pretrained calls further down.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("WARNING: HF_TOKEN not found in environment variables. You may face access issues for gated models.")
else:
    login(token=hf_token)
    print("Successfully logged in to Hugging Face")
# Fine-tuned chatbot repository on the Hugging Face Hub.
model_name = "Akshit-77/llama-3.2-3b-chatbot"

# Load its tokenizer; hf_token (possibly None) is forwarded for private/gated
# repos. Any failure is reported and re-raised so the app does not start half
# initialised.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
except Exception as err:
    print(f"Error loading tokenizer: {err}")
    raise
# 4-bit (fp4, double-quantized) weights with float16 compute to keep the
# model's memory footprint small.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)

# Load the quantized model, letting device_map="auto" place it; errors are
# reported and re-raised.
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hf_token,
        device_map="auto",
        quantization_config=bnb_config,
    )
except Exception as err:
    print(f"Error loading model: {err}")
    raise
else:
    print("Model loaded successfully")
# Custom CSS for the chatbot UI, passed to gr.Blocks(css=css) in
# create_interface().
# The #chatbot rule targets the gr.Chatbot created there with
# elem_id="chatbot". Nothing in this file emits the .message, .user-message,
# .bot-message or .timestamp classes.
# NOTE(review): confirm whether Gradio's Chatbot markup uses those class names;
# otherwise these rules are dead.
css = """
#chatbot {
font-family: Arial, sans-serif;
background-color: #e5ddd5;
}
.message {
padding: 10px 15px;
border-radius: 7.5px;
margin: 5px 0;
max-width: 75%;
position: relative;
}
.user-message {
background: #dcf8c6;
margin-left: auto;
margin-right: 10px;
}
.bot-message {
background: white;
margin-left: 10px;
}
.timestamp {
font-size: 0.7em;
color: #667781;
float: right;
margin-left: 10px;
margin-top: 3px;
}
"""
class StockDataRetriever:
    """Fetch live quotes for Indian listed companies from Yahoo Finance.

    ``stock_mapping`` translates user-facing names (NSE symbols or hyphenated
    company names) into Yahoo Finance tickers. Names that are not in the
    mapping are treated as NSE symbols and get the ``.NS`` suffix.
    """

    # Yahoo Finance suffixes for Indian exchanges (NSE / BSE). A symbol that
    # already carries one is passed through instead of being double-suffixed.
    _EXCHANGE_SUFFIXES = (".NS", ".BO")

    def __init__(self):
        self.stock_mapping = {
            # Existing mappings
            'RELIANCE': 'RELIANCE.NS',
            'TCS': 'TCS.NS',
            'HDFCBANK': 'HDFCBANK.NS',
            'INFY': 'INFY.NS',
            'ICICIBANK': 'ICICIBANK.NS',
            # Expanded Mappings
            'RELIANCE-INDUSTRIES': 'RELIANCE.NS',
            # HDFC Ltd merged into HDFC Bank in July 2023 and HDFC.NS was
            # delisted; route "HDFC" to the surviving listing.
            'HDFC': 'HDFCBANK.NS',
            'ONGC': 'ONGC.NS',
            'INDIAN-OIL-CORPORATION': 'IOC.NS',
            'ADANI-GROUP': 'ADANIENT.NS',  # Adani Enterprises as a representative
            'HERO-MOTOCORP': 'HEROMOTOCO.NS',
            'ASIAN-PAINTS': 'ASIANPAINT.NS',
            'EICHER-MOTORS': 'EICHERMOT.NS',
            'ITC': 'ITC.NS',
            'TATA-STEEL': 'TATASTEEL.NS',
            'SHRIRAM-TRANSPORT-FINANCE': 'SHRIRAMFIN.NS',
            'DR-REDDYS-LABORATORIES': 'DRREDDY.NS',
            'INFOSYS': 'INFY.NS',
            'SUN-PHARMA': 'SUNPHARMA.NS',
            'TATA-CONSULTANCY-SERVICES': 'TCS.NS',
            'MARUTI-SUZUKI': 'MARUTI.NS',
            'HCL-TECHNOLOGIES': 'HCLTECH.NS',
            'COAL-INDIA': 'COALINDIA.NS',
            # Mindtree merged into LTI; the combined company trades as LTIM.
            'LTI-MINDTREE': 'LTIM.NS',
            'HDFC-LIFE': 'HDFCLIFE.NS',
            'BAJAJ-AUTO': 'BAJAJ-AUTO.NS',
            'BRITANNIA-INDUSTRIES': 'BRITANNIA.NS',
            'HINDALCO-INDUSTRIES': 'HINDALCO.NS',
            'LARSEN-AND-TOUBRO': 'LT.NS',
            'TATA-CONSUMER-PRODUCTS': 'TATACONSUM.NS',
            'WIPRO': 'WIPRO.NS',
            'TITAN': 'TITAN.NS',
            'BAJAJ-FINANCE': 'BAJFINANCE.NS',
            'JSW-STEEL': 'JSWSTEEL.NS',
            'ICICI-BANK': 'ICICIBANK.NS',
            'INDUSIND-BANK': 'INDUSINDBK.NS',
            'BHARTI-AIRTEL': 'BHARTIARTL.NS',
            'DIVIS-LABORATORIES': 'DIVISLAB.NS',
            'SBI-LIFE-INSURANCE': 'SBILIFE.NS',
            'BAJAJ-FINSERV': 'BAJAJFINSV.NS',
            'CIPLA': 'CIPLA.NS',
            'GRASIM-INDUSTRIES': 'GRASIM.NS',
            'HINDUSTAN-UNILEVER': 'HINDUNILVR.NS',
            'MAHINDRA-AND-MAHINDRA': 'M&M.NS',
            'TATA-MOTORS': 'TATAMOTORS.NS',
            'APOLLO-HOSPITALS-ENTERPRISES': 'APOLLOHOSP.NS',
            'SBI': 'SBIN.NS',
            'KOTAK-MAHINDRA-BANK': 'KOTAKBANK.NS',
            'POWER-GRID-CORPORATION-OF-INDIA': 'POWERGRID.NS',
            'AXIS-BANK': 'AXISBANK.NS',
            'NTPC': 'NTPC.NS',
            'TECH-MAHINDRA': 'TECHM.NS',
            'ADANI-PORTS': 'ADANIPORTS.NS',
            'ULTRATECH-CEMENT': 'ULTRACEMCO.NS',
            # Nestle India's NSE symbol is NESTLEIND, not NESTLE.
            'NESTLE': 'NESTLEIND.NS',
            'BHARAT-PETROLEUM': 'BPCL.NS',
        }

    def _to_yahoo_symbol(self, symbol: str) -> str:
        """Translate a user-supplied name or symbol into a Yahoo Finance ticker.

        Lookup is case-insensitive and ignores surrounding whitespace. A spaced
        name such as "Tata Steel" also matches the hyphenated key 'TATA-STEEL'.
        An unknown symbol keeps an existing .NS/.BO suffix, or gets .NS appended.
        """
        key = symbol.strip().upper()
        if key in self.stock_mapping:
            return self.stock_mapping[key]
        hyphenated = "-".join(key.split())
        if hyphenated in self.stock_mapping:
            return self.stock_mapping[hyphenated]
        if key.endswith(self._EXCHANGE_SUFFIXES):
            return key
        return f"{key}.NS"

    def get_stock_data(self, symbol: str):
        """Fetch stock data from Yahoo Finance.

        Returns a dict with ``current_price``, ``previous_close``, ``day_high``,
        ``day_low`` ("N/A" when Yahoo omits a field) and ``last_updated``.
        ``last_updated`` is the local fetch time, not the quote time.

        Never raises. On an unknown symbol or any fetch error it returns
        ``{"error": <message>}``.
        """
        try:
            yf_symbol = self._to_yahoo_symbol(symbol)
            info = yf.Ticker(yf_symbol).info
            # A missing/empty info dict or no currentPrice means Yahoo did not
            # recognise the ticker (or has no live quote for it).
            if not info or 'currentPrice' not in info:
                return {"error": f"Stock symbol '{symbol}' not found or invalid. Please verify the symbol."}
            return {
                "current_price": info["currentPrice"],
                "previous_close": info.get("previousClose", "N/A"),
                "day_high": info.get("dayHigh", "N/A"),
                "day_low": info.get("dayLow", "N/A"),
                "last_updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
        except Exception as e:
            return {"error": f"Could not fetch stock data: {str(e)}"}
class RAGPipeline:
    """Answer chat queries, injecting live stock quotes when relevant.

    A query counts as price-related when its sentence embedding is close to
    any phrase in ``knowledge_base``. For such queries a stock name is
    extracted, its quote is fetched, and the formatted quote is put in the
    prompt as context for the causal LM.
    """

    # Characters stripped from the edges of the last-word fallback symbol, so
    # "ZOMATO?" becomes "ZOMATO" while inner "&"/"-" (M&M, BAJAJ-AUTO) survive.
    _EDGE_PUNCTUATION = ".,!?;:'\"()[]{}"

    def __init__(self, model_path):
        # NOTE: model_path is accepted for API compatibility but unused. The
        # module-level tokenizer and model are reused so the weights are
        # loaded only once.
        self.tokenizer = tokenizer
        self.model = model
        self.stock_retriever = StockDataRetriever()
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        # Template phrases a price query is expected to resemble.
        self.knowledge_base = [
            "stock price of",
            "current price",
            "stock performance",
            "today's stock price",
            "stock data for",
            "price of stock"
        ]
        self.knowledge_embeddings = self.encoder.encode(self.knowledge_base)
        # Names the extractor searches for (the ticker-mapping keys).
        self.stock_symbols = list(self.stock_retriever.stock_mapping.keys())

    @staticmethod
    def _normalize(text):
        """Upper-case *text*, drop apostrophes, and collapse every run of
        non-alphanumeric characters into a single space."""
        text = text.upper().replace("'", "").replace("\u2019", "")
        spaced = "".join(ch if ch.isalnum() else " " for ch in text)
        return " ".join(spaced.split())

    def _extract_stock_symbol(self, query):
        """Return the stock-mapping key mentioned in *query*, else a best guess.

        Keys are matched as whole words/phrases, with hyphens in keys treated
        as spaces. Matching is case-insensitive and punctuation-tolerant. The
        longest key wins, so "HDFC Life" gives 'HDFC-LIFE' rather than 'HDFC',
        and 'ITC' does not match inside "SWITCH".

        Without a match, the query's last word is returned upper-cased with
        edge punctuation removed, if at least two characters remain.
        Otherwise returns None.
        """
        padded_query = f" {self._normalize(query)} "
        # sorted() is stable, so equal-length keys keep their mapping order.
        for symbol in sorted(self.stock_symbols, key=len, reverse=True):
            phrase = self._normalize(symbol)
            if phrase and f" {phrase} " in padded_query:
                return symbol
        words = query.split()
        if words:
            candidate = words[-1].strip(self._EDGE_PUNCTUATION)
            if len(candidate) > 1:
                return candidate.upper()
        return None

    def _is_price_query(self, query):
        """Return True when *query* has cosine similarity above 0.5 with any
        ``knowledge_base`` phrase."""
        query_embedding = self.encoder.encode([query.lower()])
        similarities = cosine_similarity(query_embedding, self.knowledge_embeddings)[0]
        return max(similarities) > 0.5

    def _format_stock_data(self, stock_data):
        """Format stock data into a readable string (or return its error text)."""
        if 'error' in stock_data:
            return stock_data['error']
        return (
            f"Stock Data:\n"
            f"Current Price: ₹{stock_data['current_price']}\n"
            f"Previous Close: ₹{stock_data['previous_close']}\n"
            f"Day's High: ₹{stock_data['day_high']}\n"
            f"Day's Low: ₹{stock_data['day_low']}\n"
            f"Last Updated: {stock_data['last_updated']}"
        )

    def generate_response(self, query, max_new_tokens=500):
        """Generate the model's answer to *query*.

        Price-related queries get the fetched stock quote (or a notice that
        no symbol was found) as prompt context.

        Args:
            query: the user's message.
            max_new_tokens: cap on generated tokens, excluding the prompt.

        Returns:
            Only the newly generated text, without the prompt, stripped of
            surrounding whitespace.
        """
        stock_context = ""
        if self._is_price_query(query):
            symbol = self._extract_stock_symbol(query)
            if symbol:
                stock_data = self.stock_retriever.get_stock_data(symbol)
                stock_context = self._format_stock_data(stock_data)
            else:
                stock_context = "No specific stock symbol could be identified."
        full_prompt = (
            f"Context: {stock_context}\n\n"
            f"Question: {query}\n"
            "Answer:"
        )
        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        # Some tokenizers (e.g. Llama's) define no pad token; fall back to EOS
        # so generate() has an explicit pad id.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                # max_new_tokens (unlike max_length) does not count the
                # prompt, so long stock contexts still leave room to answer.
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_id,
            )
        # generate() returns prompt + continuation; decode only the new tokens
        # so the Context/Question scaffold is not echoed back to the user.
        new_tokens = outputs[0][prompt_len:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Build one shared RAGPipeline at import time; every chat request reuses it.
# A failure is reported and re-raised so the app never starts without it.
try:
    pipeline = RAGPipeline(model_name)
except Exception as err:
    print(f"Error initializing RAG Pipeline: {err}")
    raise
else:
    print("RAG Pipeline initialized successfully")
# Chatbot Interface function
def chat(message, history):
    """Gradio submit handler: answer *message* and append the turn to *history*.

    Args:
        message: text from the input box.
        history: current Chatbot value, a list of (user, bot) tuples, or None.

    Returns:
        ``(history, "")``. The empty string clears the input textbox.
        Blank or whitespace-only messages leave the history unchanged and are
        not sent to the model.
    """
    history = history or []
    if not message or not message.strip():
        return history, ""
    try:
        response = pipeline.generate_response(message)
    except Exception as e:
        # Show the failure in the conversation instead of breaking the UI.
        response = f"Error generating response: {str(e)}"
    history.append((message, response))
    return history, ""
# Define the Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the stock assistant.

    From top to bottom: title, intro line, the chat window (elem_id="chatbot"
    so the custom CSS can target it), an input textbox whose Enter key calls
    ``chat``, and a footer with example questions.
    """
    footer_html = """
<div style="text-align: center; margin-top: 20px; padding: 10px; background-color: #f0f0f0; border-radius: 5px;">
<p>This chatbot provides real-time Indian stock market data and can answer general questions.</p>
<p>Examples: "What's the current price of TCS?", "How is HDFC performing today?", "Tell me about RELIANCE stock"</p>
</div>
"""
    with gr.Blocks(css=css) as iface:
        gr.HTML("<h1>Indian Stock Market Assistant</h1>")
        gr.HTML("<p>Ask me about Indian stock prices or any general questions.</p>")
        chatbot = gr.Chatbot(elem_id="chatbot", height=600)
        txt = gr.Textbox(
            show_label=False,
            placeholder="Type your question here (e.g., 'What is the current price of RELIANCE?')",
        )
        # chat(message, history) -> (history, ""), which also empties the box.
        txt.submit(chat, inputs=[txt, chatbot], outputs=[chatbot, txt])
        gr.HTML(footer_html)
    return iface
# Build the UI at import time (Hugging Face Spaces imports app.py); a failure
# is reported and re-raised.
try:
    iface = create_interface()
except Exception as err:
    print(f"Error creating interface: {err}")
    raise
else:
    print("Interface created successfully")

# Start the server only when run as a script.
if __name__ == "__main__":
    iface.launch()