gauri-sharan commited on
Commit
3607a5e
·
verified ·
1 Parent(s): 95ade1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -71
app.py CHANGED
@@ -1,83 +1,45 @@
1
  import yfinance as yf
2
  import gradio as gr
3
- import chatgroq # Import Chatgroq API client
4
  from langchain.chains import LLMChain
5
- from langchain.memory import ConversationBufferMemory
6
  from langchain.prompts import PromptTemplate
 
7
 
8
- # Initialize memory for conversation
9
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
10
 
11
- # Define the prompt template for stock analysis
12
- prompt = """
13
- You are a stock market analyst. The user is asking for stock analysis.
14
- Here are the details of the stock:
15
- - Name: {stock_name}
16
- - Price: {stock_price}
17
- - Sector: {sector}
18
- - P/E Ratio: {pe_ratio}
19
 
20
- Analyze the stock performance and provide a brief summary and advice to the user.
 
 
 
21
  """
22
-
23
- # Function to fetch stock data from Yahoo Finance
24
- def get_stock_data(ticker):
25
- stock = yf.Ticker(ticker)
26
- stock_info = stock.info
27
- stock_price = stock_info.get('currentPrice', 'N/A')
28
- stock_name = stock_info.get('shortName', 'N/A')
29
- stock_sector = stock_info.get('sector', 'N/A')
30
- stock_pe_ratio = stock_info.get('trailingPE', 'N/A')
31
- return {
32
- 'stock_name': stock_name,
33
- 'stock_price': stock_price,
34
- 'sector': stock_sector,
35
- 'pe_ratio': stock_pe_ratio
36
- }
37
-
38
- # Function to generate stock analysis based on the conversation context
39
- def stock_analysis(ticker, user_query=""):
40
- # Get API key from environment (no need to pass manually)
41
- api_key = os.getenv("CHATGROQ_API_KEY") # Ensure Hugging Face secret is set
42
 
43
- if not api_key:
44
- return "API key is required for Chatgroq."
45
-
46
- # Initialize Chatgroq model with the API key
47
- chatgroq_model = chatgroq.ChatGroqModel(api_key)
48
 
49
- # Set the stock data in the prompt template
50
- stock_data = get_stock_data(ticker)
51
- prompt_input = {
52
- 'stock_name': stock_data['stock_name'],
53
- 'stock_price': stock_data['stock_price'],
54
- 'sector': stock_data['sector'],
55
- 'pe_ratio': stock_data['pe_ratio']
56
- }
57
-
58
- # Create LLM chain with Chatgroq model and prompt
59
- chain = LLMChain(llm=chatgroq_model, prompt=PromptTemplate.from_template(prompt), memory=memory)
60
-
61
- # Get analysis from Chatgroq-based agent
62
- analysis = chain.run(input=prompt_input)
63
-
64
- # Generate advice if there is a user query
65
- if user_query:
66
- return stock_data, analysis + "\n" + "Advice: " + user_query
67
- return stock_data, analysis
68
 
69
- # Gradio interface
70
- iface = gr.Interface(
71
- fn=stock_analysis,
72
- inputs=[
73
- gr.Textbox(label="Stock Ticker", placeholder="Enter stock ticker (e.g., AAPL)"),
74
- gr.Textbox(label="User Query", placeholder="Ask about the stock (optional)", optional=True)
75
- ],
76
- outputs=[
77
- gr.JSON(label="Stock Data"),
78
- gr.Textbox(label="Analysis"),
79
- ]
80
- )
81
 
82
- if __name__ == '__main__':
83
- iface.launch()
 
1
  import yfinance as yf
2
  import gradio as gr
 
3
  from langchain.chains import LLMChain
 
4
  from langchain.prompts import PromptTemplate
5
+ from langchain.llms import LlamaCpp
6
 
7
+ # Define Llama model (you must provide the path to your Llama model)
8
+ llama_model = LlamaCpp(model_path="path_to_your_llama_model.bin")
9
 
10
+ # Set up Langchain LLM
11
+ llm_chain = LLMChain(llm=llama_model)
 
 
 
 
 
 
12
 
13
+ # Define the prompt template for querying the LLM about stock details
14
+ prompt_template = """
15
+ You are a stock market assistant. Please provide detailed information about the stock.
16
+ The stock symbol is: {stock_symbol}.
17
  """
18
+ stock_prompt = PromptTemplate(input_variables=["stock_symbol"], template=prompt_template)
19
+
20
+ # Function to get stock data from Yahoo Finance
21
+ def get_stock_data(stock_symbol):
22
+ stock = yf.Ticker(stock_symbol)
23
+ stock_info = stock.info # Get company info
24
+ stock_price = stock_info['currentPrice'] # Get current stock price
25
+ company_name = stock_info.get('shortName', 'N/A')
26
+ sector = stock_info.get('sector', 'N/A')
27
+ market_cap = stock_info.get('marketCap', 'N/A')
28
+
29
+ # Generate a response using Langchain LLM
30
+ stock_details = llm_chain.run(stock_prompt.format(stock_symbol=stock_symbol))
 
 
 
 
 
 
 
31
 
32
+ return f"Company: {company_name}\nSector: {sector}\nMarket Cap: {market_cap}\nStock Price: ${stock_price}\n\nDetailed Information: {stock_details}"
 
 
 
 
33
 
34
+ # Gradio Interface
35
+ def stock_market_agent(stock_symbol):
36
+ return get_stock_data(stock_symbol)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ # Set up Gradio UI
39
+ iface = gr.Interface(fn=stock_market_agent,
40
+ inputs=gr.Textbox(label="Enter Stock Symbol (e.g., AAPL)"),
41
+ outputs="text",
42
+ title="Stock Market Agent",
43
+ description="Get real-time stock prices and company information from Yahoo Finance.")
 
 
 
 
 
 
44
 
45
+ iface.launch()