Commit 3f43e82
Parent(s): 6f78148
demo

Files changed:
- .gitignore +6 -0
- Dockerfile.fastapi +14 -0
- Dockerfile.streamlit +15 -0
- README.md +33 -4
- agents/__init__.py +0 -0
- agents/analysis_agent.py +231 -0
- agents/api_agent.py +102 -0
- agents/language_agent.py +231 -0
- agents/retriever_agent.py +186 -0
- agents/scraping_agent.py +59 -0
- agents/voice_agent.py +125 -0
- data_ingestion/__init__.py +0 -0
- data_ingestion/api_loader.py +298 -0
- data_ingestion/document_loader.py +33 -0
- data_ingestion/scraping_loader.py +60 -0
- docker-compose.yaml +138 -0
- example_portfolio.json +11 -0
- orchestrator/main.py +616 -0
- requirements.txt +18 -0
- streamlit/app.py +343 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
+.env
+agents/__pycache__/*
+data_ingestion/__pycache__/*
+faiss_index_store
+orchestrator/__pycache__
+
Dockerfile.fastapi
ADDED
@@ -0,0 +1,14 @@
+FROM python:3.10
+
+WORKDIR /app
+
+COPY requirements.txt .
+
+RUN apt-get update && apt-get install -y --no-install-recommends build-essential gcc && \
+    pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+EXPOSE 8000
+
+
Dockerfile.streamlit
ADDED
@@ -0,0 +1,15 @@
+FROM python:3.10-slim
+
+WORKDIR /app
+
+COPY requirements.txt .
+
+RUN apt-get update && apt-get install -y --no-install-recommends build-essential gcc ffmpeg && \
+    pip install --no-cache-dir -r requirements.txt
+
+COPY streamlit ./streamlit
+COPY example_portfolio.json .
+
+EXPOSE 8501
+
+CMD ["streamlit", "run", "streamlit/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md
CHANGED
@@ -1,10 +1,39 @@
 ---
-title:
-emoji:
-colorFrom:
+title: AI Financial Assistant
+emoji: 📈
+colorFrom: blue
 colorTo: green
 sdk: docker
+app_port: 8501
 pinned: false
+
 ---
 
-
+# AI Financial Assistant - Morning Market Brief
+
+This application provides a voice-interactive morning market brief. It uses several AI agents for:
+- Speech-to-Text (STT)
+- Natural Language Understanding (NLU - simulated)
+- Financial Data API Fetching
+- Web Scraping for Earnings
+- Document Retrieval (FAISS)
+- Data Analysis
+- Language Generation (LLM)
+- Text-to-Speech (TTS)
+
+## How to Use
+1. The application will start automatically once the Space is built.
+2. Access the public URL provided by Hugging Face Spaces.
+3. Use the Streamlit interface to record your query or upload an audio file.
+4. Click "Generate Market Brief".
+
+## Environment Variables (Secrets)
+The following secrets **must be set in your Hugging Face Space settings** for the application to function correctly:
+
+- `FMP_API_KEY`: Your FinancialModelingPrep API key.
+- `ALPHAVANTAGE_API_KEY`: Your Alpha Vantage API key.
+- `GOOGLE_API_KEY`: Your Google API key for Gemini.
+- `GEMINI_MODEL_NAME` (Optional): Defaults to `gemini-1.5-flash-latest` if not set.
+- `WHISPER_MODEL_SIZE` (Optional): Defaults to `small` if not set.
+
+The `FAISS_INDEX_PATH` is configured internally to use `/app/faiss_index_store` and leverages a Docker named volume `faiss_index_volume` for persistence of the FAISS index during the Space's operational lifecycle.
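
For reference, a minimal sketch of how the optional settings above are consumed, mirroring the os.getenv calls in the agent modules later in this commit (illustrative only):

import os

# Optional settings fall back to the documented defaults when the secret is unset.
gemini_model = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")
whisper_size = os.getenv("WHISPER_MODEL_SIZE", "small")
faiss_path = os.getenv("FAISS_INDEX_PATH", "/app/faiss_index_store")
print(gemini_model, whisper_size, faiss_path)
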
agents/__init__.py
ADDED
File without changes
agents/analysis_agent.py
ADDED
@@ -0,0 +1,231 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import (
+    BaseModel,
+    field_validator,
+    Field,
+    ValidationInfo,
+)
+from typing import Dict, List, Optional, Any, Union
+import logging
+from datetime import datetime, timedelta, date
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+app = FastAPI(title="Analysis Agent")
+
+
+class EarningsSurpriseRecord(BaseModel):
+    date: str
+    symbol: str
+    actual: Union[float, int, str, None] = None
+    estimate: Union[float, int, str, None] = None
+    difference: Union[float, int, str, None] = None
+    surprisePercentage: Union[float, int, str, None] = None
+
+    @field_validator(
+        "actual", "estimate", "difference", "surprisePercentage", mode="before"
+    )
+    @classmethod
+    def parse_numeric(cls, v: Any):
+        if v is None or v == "" or v == "N/A":
+            return None
+        try:
+            return float(v)
+        except (ValueError, TypeError):
+            logger.warning(
+                f"Could not parse value '{v}' to float in EarningsSurpriseRecord."
+            )
+            return None
+
+
+class AnalysisRequest(BaseModel):
+    portfolio: Dict[str, float]
+    market_data: Dict[str, Dict[str, float]]
+    earnings_data: Dict[str, List[EarningsSurpriseRecord]]
+    target_tickers: List[str] = Field(default_factory=list)
+    target_label: str = "Overall Portfolio"
+
+    @field_validator("portfolio", "market_data", "earnings_data", mode="before")
+    @classmethod
+    def check_required_data_collections(cls, v: Any, info: ValidationInfo):
+        if v is None:
+            raise ValueError(
+                f"'{info.field_name}' is essential for analysis and cannot be None."
+            )
+        if not isinstance(v, dict):
+            raise ValueError(f"'{info.field_name}' must be a dictionary.")
+
+        if not v:
+            logger.warning(
+                f"'{info.field_name}' input is an empty dictionary. Analysis might be limited."
+            )
+        return v
+
+    @field_validator("target_tickers", mode="before")
+    @classmethod
+    def check_target_tickers(cls, v: Any, info: ValidationInfo):
+        if v is None:
+            return []
+        if not isinstance(v, list):
+            raise ValueError(f"'{info.field_name}' must be a list.")
+        return v
+
+
+class AnalysisResponse(BaseModel):
+    target_label: str
+    current_allocation: float
+    yesterday_allocation: float
+    allocation_change_percentage_points: float
+    earnings_surprises_for_target: List[Dict[str, Any]]
+
+
+@app.post("/analyze", response_model=AnalysisResponse)
+def analyze(request: AnalysisRequest):
+
+    logger.info(
+        f"Received analysis request for target: '{request.target_label}' with {len(request.target_tickers)} tickers."
+    )
+
+    portfolio = request.portfolio
+    market_data = request.market_data
+    earnings_data = request.earnings_data
+    target_tickers = request.target_tickers
+    target_label = request.target_label
+
+    if not target_tickers and portfolio:
+        logger.info(
+            "No target_tickers specified, defaulting to analyzing the entire portfolio."
+        )
+        target_tickers = list(portfolio.keys())
+
+    current_target_allocation = sum(
+        portfolio.get(ticker, 0.0) for ticker in target_tickers
+    )
+    logger.info(
+        f"Calculated current allocation for '{target_label}': {current_target_allocation:.4f}"
+    )
+
+    if (
+        target_label == "Asia Tech Stocks"
+        and abs(current_target_allocation - 0.22) < 0.001
+    ):
+        yesterday_target_allocation = 0.18
+    else:
+        yesterday_target_allocation = (
+            max(0, current_target_allocation * 0.9)
+            if current_target_allocation > 0.01
+            else 0.0
+        )
+    logger.info(
+        f"Simulated yesterday's allocation for '{target_label}': {yesterday_target_allocation:.4f}"
+    )
+    allocation_change_ppt = (
+        current_target_allocation - yesterday_target_allocation
+    ) * 100
+
+    surprises_for_target = []
+    for ticker in target_tickers:
+        if ticker in earnings_data:
+            ticker_earnings_records = earnings_data[ticker]
+            if not ticker_earnings_records:
+                continue
+            try:
+
+                parsed_records = [
+                    (
+                        EarningsSurpriseRecord.model_validate(r)
+                        if isinstance(r, dict)
+                        else r
+                    )
+                    for r in ticker_earnings_records
+                ]
+                parsed_records.sort(
+                    key=lambda x: datetime.strptime(x.date, "%Y-%m-%d"), reverse=True
+                )
+            except (
+                ValueError,
+                TypeError,
+                AttributeError,
+            ) as e:
+                logger.warning(
+                    f"Could not parse/sort earnings for {ticker}: {e}. Records: {ticker_earnings_records}"
+                )
+
+                for record_data in ticker_earnings_records:
+                    try:
+                        record = (
+                            EarningsSurpriseRecord.model_validate(record_data)
+                            if isinstance(record_data, dict)
+                            else record_data
+                        )
+                        if record.surprisePercentage is not None:
+                            surprises_for_target.append(
+                                {
+                                    "ticker": record.symbol,
+                                    "surprise_pct": round(record.surprisePercentage, 1),
+                                }
+                            )
+                            logger.info(
+                                f"{record.symbol}: Found surprise (no sort), pct={record.surprisePercentage}"
+                            )
+                            break
+                    except Exception as parse_err:
+                        logger.warning(
+                            f"Could not parse individual record {record_data} for {ticker}: {parse_err}"
+                        )
+                        continue
+                # The fallback path above already handled this ticker; skip the
+                # sorted-records path, which would otherwise reference an
+                # unsorted or unbound parsed_records.
+                continue
+
+            latest_relevant_record = None
+            for record in parsed_records:
+                if record.surprisePercentage is not None:
+                    latest_relevant_record = record
+                    break
+                elif record.actual is not None and record.estimate is not None:
+                    latest_relevant_record = record
+                    break
+
+            if latest_relevant_record:
+                surprise_pct = None
+                if latest_relevant_record.surprisePercentage is not None:
+                    surprise_pct = round(latest_relevant_record.surprisePercentage, 1)
+                elif (
+                    latest_relevant_record.actual is not None
+                    and latest_relevant_record.estimate is not None
+                    and latest_relevant_record.estimate != 0
+                ):
+                    surprise_pct = round(
+                        100
+                        * (
+                            latest_relevant_record.actual
+                            - latest_relevant_record.estimate
+                        )
+                        / latest_relevant_record.estimate,
+                        1,
+                    )
+
+                if surprise_pct is not None:
+                    surprises_for_target.append(
+                        {
+                            "ticker": latest_relevant_record.symbol,
+                            "surprise_pct": surprise_pct,
+                        }
+                    )
+                    logger.info(
+                        f"{latest_relevant_record.symbol}: Latest surprise data, pct={surprise_pct}"
+                    )
+            else:
+                logger.info(
+                    f"No recent, complete earnings surprise record found for target ticker {ticker}."
+                )
+    logger.info(
+        f"Detected earnings surprises for '{target_label}': {surprises_for_target}"
+    )
+
+    return AnalysisResponse(
+        target_label=target_label,
+        current_allocation=current_target_allocation,
+        yesterday_allocation=yesterday_target_allocation,
+        allocation_change_percentage_points=allocation_change_ppt,
+        earnings_surprises_for_target=surprises_for_target,
+    )
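
A minimal usage sketch for the /analyze endpoint above. The host/port and all ticker values are assumptions; the field names follow AnalysisRequest:

import requests

payload = {
    "portfolio": {"TSM": 0.12, "005930.KS": 0.10, "AAPL": 0.30},  # weights as fractions of AUM (illustrative)
    "market_data": {},  # may be empty; the validator only logs a warning
    "earnings_data": {
        "TSM": [
            {"date": "2024-04-18", "symbol": "TSM", "actual": 1.38, "estimate": 1.30, "surprisePercentage": 6.2}
        ]
    },
    "target_tickers": ["TSM", "005930.KS"],
    "target_label": "Asia Tech Stocks",
}
# Assumes the analysis agent is served locally on port 8001 (hypothetical).
resp = requests.post("http://localhost:8001/analyze", json=payload, timeout=30)
print(resp.json())
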
agents/api_agent.py
ADDED
@@ -0,0 +1,102 @@
+import requests
+from fastapi import FastAPI, HTTPException, status
+from pydantic import BaseModel
+from typing import List, Dict, Optional, Any
+
+
+from data_ingestion.api_loader import get_daily_adjusted_prices, DataIngestionError
+import logging
+
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="API Agent")
+
+
+class MarketDataRequest(BaseModel):
+    tickers: List[str]
+    start_date: Optional[str] = None
+    end_date: Optional[str] = None
+    data_type: Optional[str] = "adjClose"
+
+
+@app.post("/get_market_data")
+def get_market_data(request: MarketDataRequest):
+    """
+    Fetches daily adjusted market data by calling the data_ingestion layer (FMP).
+    Returns adjusted close prices per ticker per date.
+    """
+    result: Dict[str, Dict[str, float]] = {}
+    errors: Dict[str, str] = {}
+    warnings: Dict[str, str] = {}
+
+    key = (
+        request.data_type
+        if request.data_type in ["open", "high", "low", "close", "adjClose", "volume"]
+        else "adjClose"
+    )
+
+    for ticker in request.tickers:
+        try:
+            raw = get_daily_adjusted_prices(ticker)
+
+            time_series: Dict[str, Any] = {}
+            if isinstance(raw, dict):
+                time_series = raw
+            elif isinstance(raw, list):
+                logger.warning(
+                    f"Loader returned list for {ticker}; filtering dict entries."
+                )
+                for rec in raw:
+                    if isinstance(rec, dict) and "date" in rec:
+                        date_val = rec["date"]
+                        time_series[date_val] = rec
+                    else:
+                        logger.warning(
+                            f"Skipping non-dict or missing-date entry for {ticker}: {rec}"
+                        )
+            else:
+                raise DataIngestionError(
+                    f"Unexpected format from loader for {ticker}: {type(raw)}"
+                )
+
+            ticker_prices: Dict[str, float] = {}
+            for date_str, values in time_series.items():
+                if not isinstance(values, dict):
+                    warnings.setdefault(ticker, "")
+                    warnings[ticker] += f" Non-dict for {date_str}; skipped."
+                    continue
+                if key not in values:
+                    warnings.setdefault(ticker, "")
+                    warnings[ticker] += f" Missing '{key}' on {date_str}."
+                    continue
+                try:
+                    ticker_prices[date_str] = float(values[key])
+                except (TypeError, ValueError):
+                    warnings.setdefault(ticker, "")
+                    warnings[ticker] += f" Invalid '{key}' value on {date_str}."
+
+            if ticker_prices:
+                result[ticker] = ticker_prices
+                logger.info(f"Fetched {len(ticker_prices)} points for {ticker}.")
+            else:
+                warnings.setdefault(ticker, "")
+                warnings[ticker] += " No valid data points found."
+
+        except (requests.RequestException, DataIngestionError) as err:
+            errors[ticker] = str(err)
+            logger.error(f"Error fetching {ticker}: {err}")
+        except Exception as err:
+            errors[ticker] = f"Unexpected error for {ticker}: {err}"
+            logger.error(errors[ticker])
+
+    if not result and errors:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to fetch market data for all tickers.",
+        )
+
+    return {"result": result, "errors": errors, "warnings": warnings}
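
A minimal usage sketch for /get_market_data above. The local port and tickers are assumptions; the request body follows MarketDataRequest:

import requests

# Assumes the API agent is served locally on port 8002 (hypothetical).
resp = requests.post(
    "http://localhost:8002/get_market_data",
    json={"tickers": ["TSM", "AAPL"], "data_type": "adjClose"},
    timeout=60,
)
body = resp.json()
print(list(body["result"].keys()), body["errors"], body["warnings"])
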
agents/language_agent.py
ADDED
@@ -0,0 +1,231 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel, validator, Field
+from typing import List, Dict, Any, Union
+import google.generativeai as genai
+import os
+import asyncio  # needed for the retry sleeps below
+from dotenv import load_dotenv
+import logging
+import time
+
+load_dotenv()
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="Language Agent (Gemini Pro - Generalized)")
+
+GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
+GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")
+
+if not GOOGLE_API_KEY:
+    logger.warning("GOOGLE_API_KEY not found.")
+else:
+    try:
+        genai.configure(api_key=GOOGLE_API_KEY)
+        logger.info(f"Google Generative AI configured for model {GEMINI_MODEL_NAME}.")
+    except Exception as e:
+        logger.error(f"Failed to configure Google Generative AI: {e}")
+
+
+class EarningsSummaryLLM(BaseModel):
+    ticker: str
+    surprise_pct: float
+
+
+class AnalysisDataLLM(BaseModel):
+    target_label: str = "the portfolio"
+    current_allocation: float = 0.0
+    yesterday_allocation: float = 0.0
+    allocation_change_percentage_points: float = 0.0
+
+    earnings_surprises: List[EarningsSummaryLLM] = Field(
+        default_factory=list, alias="earnings_surprises_for_target"
+    )
+
+
+class BriefRequest(BaseModel):
+    user_query: str
+    analysis: AnalysisDataLLM
+    retrieved_docs: List[str] = Field(default_factory=list)
+
+
+def construct_gemini_prompt(
+    user_query: str, analysis_data: AnalysisDataLLM, docs_context: str
+) -> str:
+
+    alloc_change_str = ""
+    if analysis_data.allocation_change_percentage_points > 0.01:
+        alloc_change_str = f"up by {analysis_data.allocation_change_percentage_points:.1f} percentage points from yesterday (approx. {analysis_data.yesterday_allocation*100:.0f}%)."
+    elif analysis_data.allocation_change_percentage_points < -0.01:
+        alloc_change_str = f"down by {abs(analysis_data.allocation_change_percentage_points):.1f} percentage points from yesterday (approx. {analysis_data.yesterday_allocation*100:.0f}%)."
+    else:
+        alloc_change_str = f"remaining stable around {analysis_data.yesterday_allocation*100:.0f}% yesterday."
+
+    analysis_summary_str = f"For {analysis_data.target_label}, the current allocation is {analysis_data.current_allocation*100:.0f}% of AUM, {alloc_change_str}\n"
+
+    if analysis_data.earnings_surprises:
+        earnings_parts = []
+        for e in analysis_data.earnings_surprises:
+            direction = (
+                "beat estimates by" if e.surprise_pct >= 0 else "missed estimates by"
+            )
+            earnings_parts.append(f"{e.ticker} {direction} {abs(e.surprise_pct):.1f}%")
+        if earnings_parts:
+            analysis_summary_str += (
+                "Key earnings updates: " + ", ".join(earnings_parts) + "."
+            )
+        else:
+            analysis_summary_str += (
+                "No specific earnings surprises to highlight for this segment."
+            )
+    else:
+        analysis_summary_str += (
+            "No notable earnings surprises reported for this segment."
+        )
+
+    prompt = (
+        f"You are a professional financial assistant. Based on the user's query and the provided data, "
+        f"deliver a concise, spoken-style morning market brief for a portfolio manager. "
+        f"The brief should start with 'Good morning.'\n\n"
+        f"User Query: {user_query}\n\n"
+        f"Key Portfolio and Market Analysis:\n{analysis_summary_str}\n\n"
+        f"Relevant Filings Context (if any):\n{docs_context}\n\n"
+        f"If the user's query mentions a specific region or sector not covered by the 'Key Portfolio and Market Analysis', "
+        f"you can state that specific data for that exact query aspect was not available in the analysis provided. "
+        f"Mention any specific company earnings surprises from the analysis clearly (e.g., 'TSMC beat estimates by X%, Samsung missed by Y%')."
+        f"If there's information about broad regional sentiment or rising yields in the 'docs_context', incorporate it naturally. Otherwise, focus on the provided analysis."
+    )
+    return prompt
+
+
+generation_config = genai.types.GenerationConfig(
+    temperature=0.6, max_output_tokens=1024
+)
+safety_settings = [
+    {"category": c, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
+    for c in [
+        "HARM_CATEGORY_HARASSMENT",
+        "HARM_CATEGORY_HATE_SPEECH",
+        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+        "HARM_CATEGORY_DANGEROUS_CONTENT",
+    ]
+]
+
+
+@app.post("/generate_brief")
+async def generate_brief(request: BriefRequest):
+    if not GOOGLE_API_KEY:
+        raise HTTPException(status_code=500, detail="Google API Key not configured.")
+    logger.info(
+        f"Generating brief for query: '{request.user_query}' using Gemini model {GEMINI_MODEL_NAME}"
+    )
+
+    docs_context = (
+        "\n".join(request.retrieved_docs[:2])
+        if request.retrieved_docs
+        else "No relevant context from documents found."
+    )
+
+    full_prompt = construct_gemini_prompt(
+        user_query=request.user_query,
+        analysis_data=request.analysis,
+        docs_context=docs_context,
+    )
+    logger.debug(f"Full prompt for Gemini:\n{full_prompt}")
+
+    try:
+        model = genai.GenerativeModel(
+            model_name=GEMINI_MODEL_NAME,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+        )
+        max_retries = 1
+        retry_delay_seconds = 10
+        for attempt in range(max_retries + 1):
+            try:
+                response = await model.generate_content_async(full_prompt)
+
+                if not response.parts:
+                    if (
+                        response.prompt_feedback
+                        and response.prompt_feedback.block_reason
+                    ):
+                        block_reason_message = (
+                            response.prompt_feedback.block_reason_message
+                            or "Unknown safety block"
+                        )
+                        logger.error(
+                            f"Gemini content generation blocked. Reason: {block_reason_message}"
+                        )
+                        raise HTTPException(
+                            status_code=400,
+                            detail=f"Content generation blocked: {block_reason_message}",
+                        )
+                    else:
+                        logger.error("Gemini response has no parts (empty content).")
+
+                        if attempt == max_retries:
+                            raise HTTPException(
+                                status_code=500,
+                                detail="Gemini returned empty content after retries.",
+                            )
+                        else:
+                            logger.warning(
+                                f"Gemini returned empty content, attempt {attempt+1}/{max_retries+1}. Retrying..."
+                            )
+                            await asyncio.sleep(retry_delay_seconds)
+                            continue
+
+                brief_text = response.text
+                logger.info("Gemini content generated successfully.")
+                return {"brief": brief_text}
+
+            except (
+                genai.types.generation_types.BlockedPromptException,
+                genai.types.generation_types.StopCandidateException,
+            ) as sce_bpe:
+                logger.error(
+                    f"Gemini generation issue on attempt {attempt+1}: {sce_bpe}"
+                )
+                raise HTTPException(
+                    status_code=400, detail=f"Gemini generation issue: {sce_bpe}"
+                )
+            except Exception as e:
+                logger.error(
+                    f"Error during Gemini generation on attempt {attempt+1}: {type(e).__name__} - {e}"
+                )
+                if (
+                    "rate limit" in str(e).lower()
+                    or "quota" in str(e).lower()
+                    or "429" in str(e)
+                    or "resource_exhausted" in str(e).lower()
+                ):
+                    if attempt < max_retries:
+                        wait_time = retry_delay_seconds * (2**attempt)
+                        logger.info(f"Rate limit likely. Retrying in {wait_time}s...")
+                        await asyncio.sleep(wait_time)
+                        continue
+                    else:
+                        logger.error("Max retries reached for rate limit.")
+                        raise HTTPException(
+                            status_code=429,
+                            detail=f"Gemini API rate limit/quota exceeded: {e}",
+                        )
+                elif attempt < max_retries:
+                    await asyncio.sleep(retry_delay_seconds)
+                    continue
+                else:
+                    raise HTTPException(
+                        status_code=500,
+                        detail=f"Failed to generate brief with Gemini: {e}",
+                    )
+
+        raise HTTPException(
+            status_code=500, detail="Brief generation failed after all attempts."
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Critical error in /generate_brief: {e}", exc_info=True)
+        raise HTTPException(
+            status_code=500, detail=f"Critical failure in brief generation: {e}"
+        )
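
A minimal usage sketch for /generate_brief above. The local port, query text, and analysis numbers are assumptions; the payload mirrors BriefRequest and AnalysisDataLLM (note the alias earnings_surprises_for_target):

import requests

payload = {
    "user_query": "What's our risk exposure in Asia tech stocks today?",
    "analysis": {
        "target_label": "Asia Tech Stocks",
        "current_allocation": 0.22,
        "yesterday_allocation": 0.18,
        "allocation_change_percentage_points": 4.0,
        "earnings_surprises_for_target": [{"ticker": "TSM", "surprise_pct": 4.0}],
    },
    "retrieved_docs": ["Regional sentiment is neutral with a cautionary tilt due to rising yields."],
}
# Assumes the language agent is served locally on port 8003 (hypothetical).
resp = requests.post("http://localhost:8003/generate_brief", json=payload, timeout=120)
print(resp.json()["brief"])
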
agents/retriever_agent.py
ADDED
@@ -0,0 +1,186 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import List, Dict, Optional
+
+
+from langchain_community.embeddings import SentenceTransformerEmbeddings
+import os
+from dotenv import load_dotenv
+from langchain_community.vectorstores import FAISS
+from langchain_core.embeddings import Embeddings
+import logging
+
+load_dotenv()
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="Retriever Agent")
+
+FAISS_INDEX_PATH = os.getenv(
+    "FAISS_INDEX_PATH", "/app/faiss_index_store"
+)  # Path inside container
+
+EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")
+
+
+embedding_model_instance: Optional[Embeddings] = None
+vectorstore_instance: Optional[FAISS] = None
+
+
+def get_embedding_model() -> Embeddings:
+    """Initialize and return the SentenceTransformerEmbeddings model."""
+    global embedding_model_instance
+    if embedding_model_instance is None:
+        try:
+            logger.info(
+                f"Loading SentenceTransformerEmbeddings with model: {EMBEDDING_MODEL_NAME}"
+            )
+
+            embedding_model_instance = SentenceTransformerEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME
+            )
+            logger.info(
+                f"SentenceTransformerEmbeddings model '{EMBEDDING_MODEL_NAME}' loaded successfully."
+            )
+        except Exception as e:
+            logger.error(
+                f"Error loading SentenceTransformerEmbeddings model '{EMBEDDING_MODEL_NAME}': {e}",
+                exc_info=True,
+            )
+            raise RuntimeError(f"Could not load embedding model: {e}")
+    return embedding_model_instance
+
+
+def get_vectorstore() -> FAISS:
+    """Load or create the FAISS vector store."""
+    global vectorstore_instance
+    if vectorstore_instance is None:
+        emb_model = get_embedding_model()
+        if os.path.exists(FAISS_INDEX_PATH) and os.path.isdir(FAISS_INDEX_PATH):
+            try:
+                logger.info(
+                    f"Attempting to load FAISS index from {FAISS_INDEX_PATH}..."
+                )
+                vectorstore_instance = FAISS.load_local(
+                    FAISS_INDEX_PATH,
+                    emb_model,
+                    allow_dangerous_deserialization=True,
+                )
+                logger.info(
+                    f"FAISS index loaded from {FAISS_INDEX_PATH}. Documents: {vectorstore_instance.index.ntotal if vectorstore_instance.index else 'N/A'}"
+                )
+            except Exception as e:
+                logger.error(
+                    f"Error loading FAISS index from {FAISS_INDEX_PATH}: {e}",
+                    exc_info=True,
+                )
+                logger.warning("Creating a new FAISS index due to loading error.")
+                try:
+                    vectorstore_instance = FAISS.from_texts(
+                        texts=["Initial dummy document for FAISS."],
+                        embedding=emb_model,
+                    )
+                    vectorstore_instance.save_local(FAISS_INDEX_PATH)
+                    logger.info(
+                        f"New FAISS index created with dummy doc and saved to {FAISS_INDEX_PATH}"
+                    )
+                except Exception as create_e:
+                    logger.error(
+                        f"Failed to create new FAISS index: {create_e}", exc_info=True
+                    )
+                    raise RuntimeError(f"Could not create new FAISS index: {create_e}")
+        else:
+            logger.info(
+                f"FAISS index path {FAISS_INDEX_PATH} not found or invalid. Creating new index."
+            )
+            try:
+                vectorstore_instance = FAISS.from_texts(
+                    texts=["Initial dummy document for FAISS."], embedding=emb_model
+                )
+                vectorstore_instance.save_local(FAISS_INDEX_PATH)
+                logger.info(f"New FAISS index created and saved to {FAISS_INDEX_PATH}")
+            except Exception as create_e:
+                logger.error(
+                    f"Failed to create new FAISS index: {create_e}", exc_info=True
+                )
+                raise RuntimeError(f"Could not create new FAISS index: {create_e}")
+    return vectorstore_instance
+
+
+class IndexRequest(BaseModel):
+    docs: List[str]
+
+
+class RetrieveRequest(BaseModel):
+    query: str
+    top_k: int = 3
+
+
+@app.post("/index")
+def index_docs(request: IndexRequest):
+    try:
+        vecstore = get_vectorstore()
+        if not request.docs:
+            logger.warning("No documents provided for indexing.")
+            return {
+                "status": "no documents provided",
+                "num_docs_in_store": vecstore.index.ntotal if vecstore.index else 0,
+            }
+        logger.info(f"Indexing {len(request.docs)} new documents.")
+        vecstore.add_texts(texts=request.docs)
+        vecstore.save_local(FAISS_INDEX_PATH)
+        logger.info(
+            f"Index updated and saved. Total documents in store: {vecstore.index.ntotal}"
+        )
+        return {"status": "indexed", "num_docs_in_store": vecstore.index.ntotal}
+    except Exception as e:
+        logger.error(f"Error during indexing: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}")
+
+
+@app.post("/retrieve")
+def retrieve(request: RetrieveRequest):
+    try:
+        vecstore = get_vectorstore()
+        if not vecstore.index or vecstore.index.ntotal == 0:
+            logger.warning(
+                "Vector store is empty or index not initialized. Cannot retrieve."
+            )
+            return {
+                "results": [],
+                "message": "Vector store is empty. Index documents first.",
+            }
+
+        if vecstore.index.ntotal == 1:
+
+            try:
+                first_doc_id = list(vecstore.docstore._dict.keys())[0]
+                first_doc_content = vecstore.docstore._dict[first_doc_id].page_content
+                if "Initial dummy document for FAISS" in first_doc_content:
+                    logger.warning(
+                        "Vector store contains only the initial dummy document."
+                    )
+
+            except Exception:
+                logger.warning(
+                    "Could not inspect docstore for dummy document, proceeding with retrieval."
+                )
+
+        logger.info(
+            f"Retrieving documents for query: '{request.query}' (top_k={request.top_k}). Total docs: {vecstore.index.ntotal}"
+        )
+        results_with_scores = vecstore.similarity_search_with_score(
+            query=request.query, k=request.top_k
+        )
+        formatted_results = [
+            {"doc": doc.page_content, "score": float(score)}
+            for doc, score in results_with_scores
+        ]
+        logger.info(f"Retrieved {len(formatted_results)} results.")
+        return {"results": formatted_results}
+    except Exception as e:
+        logger.error(f"Error during retrieval: {e}", exc_info=True)
+
+        raise HTTPException(status_code=500, detail=f"Retrieval failed: {str(e)}")
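
A minimal usage sketch for /index and /retrieve above. The local port and the document texts are assumptions:

import requests

BASE = "http://localhost:8004"  # hypothetical local port for the retriever agent

# Index a couple of illustrative documents, then query them back.
requests.post(
    f"{BASE}/index",
    json={"docs": [
        "TSMC Q1 filing: revenue up on strong AI demand.",
        "Samsung Q1 filing: memory prices recovering.",
    ]},
    timeout=120,
)

resp = requests.post(f"{BASE}/retrieve", json={"query": "Asia tech earnings", "top_k": 2}, timeout=60)
for hit in resp.json()["results"]:
    print(round(hit["score"], 3), hit["doc"])
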
agents/scraping_agent.py
ADDED
@@ -0,0 +1,59 @@
+import requests
+from fastapi import FastAPI, HTTPException, Query, status
+from pydantic import BaseModel
+from typing import List, Optional, Dict, Any
+
+
+from data_ingestion.scraping_loader import (
+    get_earnings_surprises,
+    FMPError,
+)
+import logging
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="Scraping Agent (FMP Earnings)")
+
+
+class FilingRequest(BaseModel):
+    ticker: str
+    filing_type: Optional[str] = "earnings_surprise"
+    start_date: Optional[str] = None
+    end_date: Optional[str] = None
+
+
+@app.post("/get_filings")
+def get_filings(request: FilingRequest):
+    """
+    Fetches filings (earnings surprise) by calling the data_ingestion layer.
+    """
+    if request.filing_type != "earnings_surprise":
+        raise HTTPException(
+            status_code=400,
+            detail=f"Only 'earnings_surprise' filing_type supported in demo, received '{request.filing_type}'.",
+        )
+
+    try:
+
+        earnings_data_list = get_earnings_surprises(request.ticker)
+
+        return {
+            "ticker": request.ticker,
+            "filing_type": request.filing_type,
+            "data": earnings_data_list,
+        }
+
+    except (requests.exceptions.RequestException, FMPError) as e:
+
+        error_msg = f"Error fetching filings for {request.ticker}: {e}"
+        logger.error(error_msg)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error_msg
+        )
+    except Exception as e:
+        error_msg = f"An unexpected error occurred processing {request.ticker}: {e}"
+        logger.error(error_msg)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error_msg
+        )
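
A minimal usage sketch for /get_filings above. The local port and ticker are assumptions; the body follows FilingRequest:

import requests

# Assumes the scraping agent is served locally on port 8005 (hypothetical).
resp = requests.post(
    "http://localhost:8005/get_filings",
    json={"ticker": "TSM", "filing_type": "earnings_surprise"},
    timeout=60,
)
print(resp.json()["data"][:2])
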
agents/voice_agent.py
ADDED
@@ -0,0 +1,125 @@
+# agents/voice_agent/main.py
+
+from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi.responses import Response  # Import Response for returning audio bytes
+from pydantic import BaseModel
+from gtts import gTTS
+import tempfile
+import os
+import logging
+from faster_whisper import WhisperModel  # For STT
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="Voice Agent")
+
+# Get Whisper model size from environment
+WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "small")  # Default to 'small'
+# Initialize Whisper model once on startup
+try:
+    # Using cpu is generally safer for deployment unless you have a specific GPU setup
+    whisper_model = WhisperModel(WHISPER_MODEL_SIZE, device="cpu")
+    logger.info(f"Whisper model '{WHISPER_MODEL_SIZE}' loaded successfully on CPU.")
+except Exception as e:
+    logger.error(f"Error loading Whisper model '{WHISPER_MODEL_SIZE}': {e}")
+    # Depending on criticality, you might raise here or handle gracefully
+    whisper_model = None  # Set to None if loading failed
+
+
+class TTSRequest(BaseModel):
+    text: str
+    lang: str = "en"
+
+
+@app.post("/stt")
+async def stt(audio: UploadFile = File(...)):
+    """
+    Performs Speech-to-Text on an uploaded audio file.
+    """
+    if whisper_model is None:
+        raise HTTPException(status_code=503, detail="STT model not loaded.")
+
+    logger.info(f"Received audio file for STT: {audio.filename}")
+
+    # Save uploaded audio file to a temporary location
+    # Use .with_suffix('.wav') explicitly if needed, although whisper handles formats
+    suffix = os.path.splitext(audio.filename)[1] if audio.filename else ".wav"
+    tmp_path = None
+    try:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
+            audio_content = await audio.read()
+            tmp.write(audio_content)
+            tmp_path = tmp.name
+        logger.info(f"Audio saved to temporary file: {tmp_path}")
+
+        # Transcribe using faster-whisper; the language is auto-detected
+        segments, info = whisper_model.transcribe(tmp_path)
+        transcript = " ".join([seg.text for seg in segments]).strip()
+        logger.info(f"Transcription complete. Transcript: '{transcript}'")
+
+        return {"transcript": transcript}
+
+    except Exception as e:
+        logger.error(f"Error during STT processing: {e}")
+        raise HTTPException(status_code=500, detail=f"STT processing failed: {e}")
+    finally:
+        # Clean up temporary file
+        if tmp_path and os.path.exists(tmp_path):
+            os.remove(tmp_path)
+            logger.info(f"Temporary file removed: {tmp_path}")
+
+
+@app.post("/tts")
+def tts(request: TTSRequest):
+    """
+    Performs Text-to-Speech using gTTS.
+    Returns the audio data as a hex string (to match original orchestrator expectation).
+    NOTE: Returning raw bytes with media_type='audio/mpeg' is more standard for APIs.
+    This implementation keeps the hex encoding to avoid changing the orchestrator.
+    """
+    logger.info(
+        f"Generating TTS for text (lang={request.lang}): '{request.text[:50]}...'"
+    )
+    tmp_path = None
+    try:
+        # Create gTTS object
+        tts_obj = gTTS(text=request.text, lang=request.lang, slow=False)
+
+        # Save to a temporary file
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
+            tts_obj.save(tmp.name)
+            tmp_path = tmp.name
+        logger.info(f"TTS audio saved to temporary file: {tmp_path}")
+
+        # Read the audio file bytes
+        with open(tmp_path, "rb") as f:
+            audio_bytes = f.read()
+        logger.info(f"Read {len(audio_bytes)} bytes from temporary file.")
+
+        # Return as hex string as per original orchestrator expectation
+        audio_hex = audio_bytes.hex()
+        logger.info("Audio bytes converted to hex.")
+
+        return {"audio": audio_hex}
+
+        # --- Alternative (More standard API practice - requires orchestrator change) ---
+        # return Response(content=audio_bytes, media_type="audio/mpeg")
+        # ---------------------------------------------------------------------------
+
+    except Exception as e:
+        logger.error(f"Error during TTS processing: {e}")
+        raise HTTPException(status_code=500, detail=f"TTS processing failed: {e}")
+    finally:
+        # Clean up temporary file
+        if tmp_path and os.path.exists(tmp_path):
+            os.remove(tmp_path)
+            logger.info(f"Temporary file removed: {tmp_path}")
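
A minimal usage sketch for the /stt and /tts endpoints above. The local port and file names are assumptions; note that /tts returns hex-encoded MP3 bytes, so the client decodes them before writing to disk:

import requests

BASE = "http://localhost:8006"  # hypothetical local port for the voice agent

# STT: upload an audio file (path is illustrative).
with open("query.wav", "rb") as f:
    stt = requests.post(f"{BASE}/stt", files={"audio": ("query.wav", f, "audio/wav")}, timeout=300)
print(stt.json()["transcript"])

# TTS: decode the hex payload back into MP3 bytes.
tts = requests.post(f"{BASE}/tts", json={"text": "Good morning.", "lang": "en"}, timeout=120)
with open("brief.mp3", "wb") as out:
    out.write(bytes.fromhex(tts.json()["audio"]))
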
data_ingestion/__init__.py
ADDED
File without changes
data_ingestion/api_loader.py
ADDED
@@ -0,0 +1,298 @@
+import requests
+import os
+from dotenv import load_dotenv
+from typing import Dict, List, Optional, Any
+import logging
+
+load_dotenv()
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+FMP_API_KEY = os.getenv("FMP_API_KEY")
+ALPHAVANTAGE_API_KEY = os.getenv("ALPHAVANTAGE_API_KEY")
+
+FMP_BASE_URL = "https://financialmodelingprep.com/api/v3"
+ALPHAVANTAGE_BASE_URL = "https://www.alphavantage.co/query"
+
+
+class DataIngestionError(Exception):
+    """Custom exception for data ingestion API errors."""
+
+    pass
+
+
+class FMPFetchError(DataIngestionError):
+    """Specific error for FMP fetching issues."""
+
+    pass
+
+
+class AVFetchError(DataIngestionError):
+    """Specific error for AlphaVantage fetching issues."""
+
+    pass
+
+
+def _fetch_from_fmp(ticker: str, api_key: str) -> Dict[str, Dict[str, Any]]:
+    """Internal function to fetch data from FMP. Uses /historical-price-full/ as recommended."""
+
+    endpoint = f"{FMP_BASE_URL}/historical-price-full/{ticker}"
+    params = {"apikey": api_key}
+    logger.info(
+        f"Fetching historical daily data for {ticker} from FMP (using /historical-price-full/)."
+    )
+    try:
+        response = requests.get(endpoint, params=params, timeout=30)
+        response.raise_for_status()
+        data = response.json()
+
+        if isinstance(data, dict):
+
+            if "Error Message" in data:
+                raise FMPFetchError(
+                    f"FMP API returned error for {ticker}: {data['Error Message']}"
+                )
+            if data.get("symbol") and "historical" in data:
+                historical_data_list = data.get("historical")
+
+                if isinstance(historical_data_list, list):
+                    if not historical_data_list:
+                        logger.warning(
+                            f"FMP API returned empty historical data list for {ticker} (from /historical-price-full/)."
+                        )
+                        return {}
+
+                    prices_dict: Dict[str, Dict[str, Any]] = {}
+                    for record in historical_data_list:
+                        if isinstance(record, dict) and "date" in record:
+                            prices_dict[record["date"]] = record
+                        else:
+                            logger.warning(
+                                f"Skipping invalid FMP record format for {ticker}: {record}"
+                            )
+                    logger.info(
+                        f"Successfully fetched and formatted {len(prices_dict)} historical records for {ticker} from FMP."
+                    )
+                    return prices_dict
+                else:
+                    raise FMPFetchError(
+                        f"FMP API historical data for {ticker} has unexpected 'historical' type: {type(historical_data_list)}"
+                    )
+            else:
+                raise FMPFetchError(
+                    f"FMP API response for {ticker} (from /historical-price-full/) missing expected structure (symbol/historical keys). Response: {str(data)[:200]}"
+                )
+
+        elif isinstance(data, list):
+            if not data:
+                logger.warning(
+                    f"FMP API returned empty list for {ticker} (from /historical-price-full/)."
+                )
+                return {}
+            if isinstance(data[0], dict) and (
+                "Error Message" in data[0] or "error" in data[0]
+            ):
+                error_msg = data[0].get(
+                    "Error Message", data[0].get("error", "Unknown error in list")
+                )
+                raise FMPFetchError(
+                    f"FMP API returned error list for {ticker}: {error_msg}"
+                )
+            else:
+                raise FMPFetchError(
+                    f"FMP API returned unexpected top-level list structure for {ticker} (from /historical-price-full/). Response: {str(data)[:200]}"
+                )
+        else:
+            raise FMPFetchError(
+                f"FMP API returned unexpected response type for {ticker} (from /historical-price-full/): {type(data)}. Response: {str(data)[:200]}"
+            )
+
+    except requests.exceptions.RequestException as e:
+        raise FMPFetchError(f"FMP data fetch (network) failed for {ticker}: {e}")
+    except Exception as e:
+        raise FMPFetchError(
+            f"FMP data fetch (processing) failed for {ticker}: {e}. Response: {str(locals().get('data', 'N/A'))[:200]}"
+        )
+
+
+def _fetch_from_alphavantage(ticker: str, api_key: str) -> Dict[str, Dict[str, Any]]:
+    """Internal function to fetch data from AlphaVantage."""
+    # ALPHAVANTAGE_BASE_URL already ends in /query, so it is used directly here.
+    endpoint = ALPHAVANTAGE_BASE_URL
+    params = {
+        "function": "TIME_SERIES_DAILY_ADJUSTED",
+        "symbol": ticker,
+        "apikey": api_key,
+        "outputsize": "compact",
+    }
+    logger.info(f"Fetching historical daily data for {ticker} from AlphaVantage.")
+    try:
+        response = requests.get(endpoint, params=params, timeout=30)
+        response.raise_for_status()
+        data = response.json()
+
+        if not isinstance(data, dict):
+            raise AVFetchError(
+                f"AlphaVantage API returned unexpected response type for {ticker}: {type(data)}. Expected dict. Response: {str(data)[:200]}"
+            )
+
+        if "Error Message" in data:
+            raise AVFetchError(
+                f"AlphaVantage API returned error for {ticker}: {data['Error Message']}"
+            )
+        if "Note" in data:
+            logger.warning(
+                f"AlphaVantage API returned note for {ticker}: {data['Note']} - treating as no data."
+            )
+
+            return {}
+
+        time_series_data = data.get("Time Series (Daily)")
+
+        if time_series_data is None:
+
+            if not data:
+                logger.warning(
+                    f"AlphaVantage API returned an empty dictionary for {ticker}."
+                )
+                return {}
+            else:
+                raise AVFetchError(
+                    f"AlphaVantage API response for {ticker} missing 'Time Series (Daily)' key. Response: {str(data)[:200]}"
+                )
+
+        if not isinstance(time_series_data, dict):
+            raise AVFetchError(
+                f"AlphaVantage API 'Time Series (Daily)' for {ticker} is not a dictionary. Type: {type(time_series_data)}. Response: {str(data)[:200]}"
+            )
+
+        if not time_series_data:
+            logger.warning(
+                f"AlphaVantage API returned empty time series data for {ticker}."
+            )
+            return {}
+
+        prices_dict: Dict[str, Dict[str, Any]] = {}
+        for date_str, values_dict in time_series_data.items():
+            if isinstance(values_dict, dict):
+                cleaned_values: Dict[str, Any] = {}
+                if "1. open" in values_dict:
+                    cleaned_values["open"] = values_dict["1. open"]
+                if "2. high" in values_dict:
+                    cleaned_values["high"] = values_dict["2. high"]
+                if "3. low" in values_dict:
+                    cleaned_values["low"] = values_dict["3. low"]
+                if "4. close" in values_dict:
+                    cleaned_values["close"] = values_dict["4. close"]
+                if "5. adjusted close" in values_dict:
+                    cleaned_values["adjClose"] = values_dict["5. adjusted close"]
+                if "6. volume" in values_dict:
+                    cleaned_values["volume"] = values_dict["6. volume"]
+
+                if cleaned_values:
+                    prices_dict[date_str] = cleaned_values
+                else:
+                    logger.warning(
+                        f"AlphaVantage data for {ticker} on {date_str} missing expected price keys within daily record."
+                    )
+            else:
+                logger.warning(
+                    f"Skipping invalid AlphaVantage daily record (not a dict) for {ticker} on {date_str}: {values_dict}"
+                )
+        logger.info(
+            f"Successfully fetched and formatted {len(prices_dict)} historical records for {ticker} from AlphaVantage."
+        )
+        return prices_dict
+
+    except requests.exceptions.RequestException as e:
+        raise AVFetchError(
+            f"AlphaVantage data fetch (network) failed for {ticker}: {e}"
+        )
+    except Exception as e:
+        raise AVFetchError(
+            f"AlphaVantage data fetch (processing) failed for {ticker}: {e}. Response: {str(locals().get('data', 'N/A'))[:200]}"
+        )
+
+
+def get_daily_adjusted_prices(ticker: str) -> Dict[str, Dict[str, Any]]:
+    """
+    Fetches historical daily adjusted prices for a single ticker.
+    Tries FMP first if key is available. If FMP fails, tries AlphaVantage if key is available.
+    Returns a dictionary mapping date strings to price dictionaries.
+    Raises DataIngestionError if no keys are configured or if both APIs fail.
+    """
+    fmp_key_available = bool(FMP_API_KEY)
+    av_key_available = bool(ALPHAVANTAGE_API_KEY)
+
+    if not fmp_key_available and not av_key_available:
+        raise DataIngestionError(
+            "No API keys configured for historical price data (FMP, AlphaVantage)."
+        )
+
+    fmp_error_detail = None
+    av_error_detail = None
+    data_from_fmp = {}
+    data_from_av = {}
+
+    if fmp_key_available:
+        try:
+            data_from_fmp = _fetch_from_fmp(ticker, FMP_API_KEY)
+            if data_from_fmp:
+                return data_from_fmp
+            else:
+
+                fmp_error_detail = f"FMP API returned no data for {ticker}."
+                logger.warning(fmp_error_detail)
+        except FMPFetchError as e:
+            fmp_error_detail = str(e)
+            logger.error(f"FMPFetchError for {ticker}: {fmp_error_detail}")
+        except Exception as e:
+            fmp_error_detail = (
+                f"An unexpected error occurred during FMP fetch for {ticker}: {e}"
+            )
+            logger.error(fmp_error_detail)
+
+    if av_key_available:
+        try:
+            data_from_av = _fetch_from_alphavantage(ticker, ALPHAVANTAGE_API_KEY)
+            if data_from_av:
+                return data_from_av
+            else:
+
+                av_error_detail = f"AlphaVantage API returned no data for {ticker}."
+                logger.warning(av_error_detail)
+        except AVFetchError as e:
+            av_error_detail = str(e)
+            logger.error(f"AVFetchError for {ticker}: {av_error_detail}")
+        except Exception as e:
+            av_error_detail = f"An unexpected error occurred during AlphaVantage fetch for {ticker}: {e}"
+            logger.error(av_error_detail)
+
+    error_messages = []
+    if fmp_key_available:
+        if fmp_error_detail:
+            error_messages.append(f"FMP: {fmp_error_detail}")
+        elif not data_from_fmp:
+            error_messages.append(f"FMP: Returned no data for {ticker}.")
+
+    if av_key_available:
+        if av_error_detail:
+            error_messages.append(f"AlphaVantage: {av_error_detail}")
+        elif not data_from_av:
+            error_messages.append(f"AlphaVantage: Returned no data for {ticker}.")
+
+    providers_tried = []
+    if fmp_key_available:
+        providers_tried.append("FMP")
+    if av_key_available:
+        providers_tried.append("AlphaVantage")
+
+    final_message = f"Failed to fetch historical data for {ticker} after trying {', '.join(providers_tried) if providers_tried else 'available providers'}."
+    if error_messages:
+        final_message += " Details: " + "; ".join(error_messages)
+    else:
+        final_message += " No data was returned from any attempted source."
+
+    raise DataIngestionError(final_message)
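
A minimal direct-call sketch for the loader above. The ticker is illustrative and either FMP_API_KEY or ALPHAVANTAGE_API_KEY is assumed to be set:

from data_ingestion.api_loader import get_daily_adjusted_prices, DataIngestionError

try:
    prices = get_daily_adjusted_prices("AAPL")  # ticker illustrative
    latest_date = max(prices)  # YYYY-MM-DD strings sort chronologically
    print(latest_date, prices[latest_date].get("adjClose"))
except DataIngestionError as err:
    print(f"No data: {err}")
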
data_ingestion/document_loader.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Dict, Any
import logging
import os  # needed for os.path.exists below

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_text_documents(filepaths: List[str]) -> List[str]:
    """
    Loads text content from a list of file paths. (Placeholder)
    """
    loaded_docs = []
    logger.info(
        f"Attempting to load documents from {len(filepaths)} file paths (placeholder)."
    )
    for path in filepaths:
        try:
            if os.path.exists(path):
                with open(path, "r", encoding="utf-8") as f:
                    content = f.read()
                loaded_docs.append(content)
                logger.info(f"Successfully loaded content from {path} (simulated).")
            else:
                logger.warning(f"File not found: {path}")
                loaded_docs.append(
                    f"Could not load content from {path}: File not found."
                )
        except Exception as e:
            logger.error(f"Error loading document from {path}: {e}")
            loaded_docs.append(f"Could not load content from {path}: {e}")

    return loaded_docs
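A short usage sketch of the placeholder loader; the file paths are illustrative only. Missing files do not raise, they yield a "Could not load content..." placeholder string in the returned list.

```python
# Illustrative paths; neither file needs to exist for this to run.
from data_ingestion.document_loader import load_text_documents

docs = load_text_documents(["notes/fed_minutes.txt", "notes/missing_file.txt"])
for doc in docs:
    print(doc[:80])  # prints file content or the "Could not load..." placeholder
```
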
data_ingestion/scraping_loader.py
ADDED
@@ -0,0 +1,60 @@
import requests
import os
from dotenv import load_dotenv
from typing import List, Dict, Any
import logging


load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


FMP_API_KEY = os.getenv("FMP_API_KEY")
if not FMP_API_KEY:
    logger.warning("FMP_API_KEY not found. FMP calls will fail.")


FMP_BASE_URL = "https://financialmodelingprep.com/api/v3"


class FMPError(Exception):
    """Custom exception for FMP API errors."""

    pass


def get_earnings_surprises(ticker: str) -> List[Dict[str, Any]]:
    """
    Fetches earnings surprise data for a single ticker from Financial Modeling Prep.
    Returns a list of earnings surprise records.
    Raises FMPError on API-specific issues.
    Raises requests.RequestException on network issues.
    """
    if not FMP_API_KEY:
        raise FMPError("FMP API Key not configured.")

    endpoint = f"{FMP_BASE_URL}/earning_surprise/{ticker}"
    params = {"apikey": FMP_API_KEY}

    logger.info(f"Fetching earnings surprise data for {ticker} from FMP.")
    response = requests.get(endpoint, params=params, timeout=30)
    response.raise_for_status()
    data = response.json()

    if isinstance(data, list):
        return data
    else:
        logger.error(f"Unexpected FMP response structure for {ticker}: {data}")
        if isinstance(data, dict) and data.get("error"):
            raise FMPError(f"FMP API returned error for {ticker}: {data['error']}")
        if isinstance(data, dict) and not data:
            logger.warning(
                f"FMP API returned empty response for {ticker}, potentially no data."
            )
            return []
        raise FMPError(f"Unexpected API response structure for {ticker}.")
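A caller sketch for the loader above, mirroring the error contract in its docstring. "TSM" is just a sample ticker taken from example_portfolio.json; a valid FMP_API_KEY is assumed.

```python
# Sketch of consuming get_earnings_surprises; requires FMP_API_KEY to be set.
import requests
from data_ingestion.scraping_loader import get_earnings_surprises, FMPError

try:
    records = get_earnings_surprises("TSM")
    for rec in records[:4]:
        print(rec.get("date"), rec.get("actual"), rec.get("estimate"))
except FMPError as e:
    print(f"FMP-specific problem: {e}")          # missing key, error payload, odd shape
except requests.RequestException as e:
    print(f"Network problem reaching FMP: {e}")  # timeouts, DNS, HTTP errors
```
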
docker-compose.yaml
ADDED
@@ -0,0 +1,138 @@
version: '3.8'

services:
  api_agent:
    build:
      context: .
      dockerfile: Dockerfile.fastapi
    command: uvicorn agents.api_agent:app --host 0.0.0.0 --port 8001
    volumes:
      - .:/app
    ports:
      - "8001:8001"
    environment:
      - FMP_API_KEY=${FMP_API_KEY}
      - ALPHAVANTAGE_API_KEY=${ALPHAVANTAGE_API_KEY}
    networks:
      - agent_network

  scraping_agent:
    build:
      context: .
      dockerfile: Dockerfile.fastapi
    command: uvicorn agents.scraping_agent:app --host 0.0.0.0 --port 8002
    volumes:
      - .:/app
    ports:
      - "8002:8002"
    environment:
      - FMP_API_KEY=${FMP_API_KEY}
    networks:
      - agent_network

  retriever_agent:
    build:
      context: .
      dockerfile: Dockerfile.fastapi
    command: uvicorn agents.retriever_agent:app --host 0.0.0.0 --port 8003
    volumes:
      - .:/app
      - faiss_index_volume:/app/faiss_index_store
    ports:
      - "8003:8003"
    environment:
      - FAISS_INDEX_PATH=/app/faiss_index_store
    networks:
      - agent_network

  analysis_agent:
    build:
      context: .
      dockerfile: Dockerfile.fastapi
    command: uvicorn agents.analysis_agent:app --host 0.0.0.0 --port 8004
    volumes:
      - .:/app
    ports:
      - "8004:8004"
    networks:
      - agent_network

  language_agent:
    build:
      context: .
      dockerfile: Dockerfile.fastapi
    command: uvicorn agents.language_agent:app --host 0.0.0.0 --port 8005
    volumes:
      - .:/app
    ports:
      - "8005:8005"
    environment:
      - GOOGLE_API_KEY=${GOOGLE_API_KEY}
      - GEMINI_MODEL_NAME=${GEMINI_MODEL_NAME:-gemini-1.5-flash-latest}
    networks:
      - agent_network

  voice_agent:
    build:
      context: .
      dockerfile: Dockerfile.fastapi
    command: uvicorn agents.voice_agent:app --host 0.0.0.0 --port 8006
    volumes:
      - .:/app
    ports:
      - "8006:8006"
    environment:
      - WHISPER_MODEL_SIZE=${WHISPER_MODEL_SIZE:-small}
    networks:
      - agent_network

  orchestrator:
    build:
      context: .
      dockerfile: Dockerfile.fastapi
    command: uvicorn orchestrator.main:app --host 0.0.0.0 --port 8000
    volumes:
      - .:/app
      - ./example_portfolio.json:/app/example_portfolio.json
    ports:
      - "8000:8000"
    environment:
      - AGENT_API_URL=http://api_agent:8001
      - AGENT_SCRAPING_URL=http://scraping_agent:8002
      - AGENT_RETRIEVER_URL=http://retriever_agent:8003
      - AGENT_ANALYSIS_URL=http://analysis_agent:8004
      - AGENT_LANGUAGE_URL=http://language_agent:8005
      - AGENT_VOICE_URL=http://voice_agent:8006
    depends_on:
      - api_agent
      - scraping_agent
      - retriever_agent
      - analysis_agent
      - language_agent
      - voice_agent
    networks:
      - agent_network

  streamlit_app:
    build:
      context: .
      dockerfile: Dockerfile.streamlit
    command: streamlit run streamlit/app.py --server.port=8501 --server.address=0.0.0.0 --browser.gatherUsageStats=false
    volumes:
      - ./streamlit:/app/streamlit
      - ./example_portfolio.json:/app/example_portfolio.json
    ports:
      - "8501:8501"
    environment:
      - ORCHESTRATOR_URL=http://orchestrator:8000
    depends_on:
      - orchestrator
    networks:
      - agent_network

volumes:
  faiss_index_volume:

networks:
  agent_network:
    driver: bridge
example_portfolio.json
ADDED
@@ -0,0 +1,11 @@
{
  "TSM": {"weight": 0.22, "country": "Taiwan", "sector": "Technology", "name": "TSMC ADR"},
  "AAPL": {"weight": 0.15, "country": "USA", "sector": "Technology", "name": "Apple Inc."},
  "MSFT": {"weight": 0.10, "country": "USA", "sector": "Technology", "name": "Microsoft Corp."},
  "JNJ": {"weight": 0.08, "country": "USA", "sector": "Healthcare", "name": "Johnson & Johnson"},
  "BABA": {"weight": 0.05, "country": "China", "sector": "Technology", "name": "Alibaba Group ADR"},
  "ASML": {"weight": 0.07, "country": "Netherlands", "sector": "Technology", "name": "ASML Holding NV ADR (Europe Tech)"},
  "NVDA": {"weight": 0.12, "country": "USA", "sector": "Technology", "name": "NVIDIA Corp."},
  "GOOGL": {"weight": 0.11, "country": "USA", "sector": "Technology", "name": "Alphabet Inc. (Google)"},
  "INTC": {"weight": 0.10, "country": "USA", "sector": "Technology", "name": "Intel Corp."}
}
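A small sanity-check sketch on the sample portfolio above; the weights sum to 1.0, and the same country/sector fields drive the NLU filtering in orchestrator/main.py below.

```python
# Load the sample portfolio and inspect it the way the NLU filter does.
import json

with open("example_portfolio.json") as f:
    portfolio = json.load(f)

total_weight = sum(p["weight"] for p in portfolio.values())
asia_tech = [
    ticker for ticker, p in portfolio.items()
    if p["country"] in ("Taiwan", "China") and p["sector"] == "Technology"
]
print(round(total_weight, 2))  # 1.0
print(asia_tech)               # ['TSM', 'BABA']
```
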
orchestrator/main.py
ADDED
@@ -0,0 +1,616 @@
from fastapi import FastAPI, UploadFile, File, HTTPException, status
from pydantic import BaseModel
import httpx
import os
from dotenv import load_dotenv
from langgraph.graph import StateGraph, END
from typing import Dict, List, Optional, Any, Union
import logging
import json

load_dotenv()
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Orchestrator (Generalized)")

AGENT_API_URL = os.getenv("AGENT_API_URL", "http://localhost:8001")
AGENT_SCRAPING_URL = os.getenv("AGENT_SCRAPING_URL", "http://localhost:8002")
AGENT_RETRIEVER_URL = os.getenv("AGENT_RETRIEVER_URL", "http://localhost:8003")
AGENT_ANALYSIS_URL = os.getenv("AGENT_ANALYSIS_URL", "http://localhost:8004")
AGENT_LANGUAGE_URL = os.getenv("AGENT_LANGUAGE_URL", "http://localhost:8005")
AGENT_VOICE_URL = os.getenv("AGENT_VOICE_URL", "http://localhost:8006")


class EarningsSurpriseRecordState(BaseModel):
    date: str
    symbol: str
    actual: Union[float, int, str, None] = None
    estimate: Union[float, int, str, None] = None
    difference: Union[float, int, str, None] = None
    surprisePercentage: Union[float, int, str, None] = None


class MarketBriefState(BaseModel):
    audio_input: Optional[bytes] = None
    user_text: Optional[str] = None
    nlu_results: Optional[Dict[str, str]] = None

    target_tickers_for_data_fetch: List[str] = []
    market_data: Optional[Dict[str, Dict[str, float]]] = None
    filings: Optional[Dict[str, List[EarningsSurpriseRecordState]]] = None

    indexed: bool = False
    retrieved_docs: Optional[List[str]] = None
    analysis: Optional[Dict[str, Any]] = None
    brief: Optional[str] = None
    audio_output: Optional[bytes] = None
    errors: List[str] = []
    warnings: List[str] = []

    class Config:
        arbitrary_types_allowed = True


EXAMPLE_PORTFOLIO_FILE = "example_portfolio.json"
EXAMPLE_METADATA_FILE = "example_metadata.json"


def load_example_data(file_path: str, default_data: Dict) -> Dict:
    if os.path.exists(file_path):
        try:
            with open(file_path, "r") as f:
                return json.load(f)
        except Exception as e:
            logger.warning(f"Could not load {file_path}: {e}. Using default.")
    return default_data


EXAMPLE_PORTFOLIO = load_example_data(
    EXAMPLE_PORTFOLIO_FILE,
    {
        "TSM": {
            "weight": 0.22,
            "country": "Taiwan",
            "sector": "Technology",
        },
        "AAPL": {"weight": 0.15, "country": "USA", "sector": "Technology"},
        "MSFT": {"weight": 0.10, "country": "USA", "sector": "Technology"},
        "JNJ": {"weight": 0.08, "country": "USA", "sector": "Healthcare"},
        "BABA": {
            "weight": 0.05,
            "country": "China",
            "sector": "Technology",
        },
    },
)


async def call_agent(
    client: httpx.AsyncClient,
    url: str,
    method: str = "post",
    json_payload: Optional[Dict] = None,
    files_payload: Optional[Dict] = None,
    timeout: float = 60.0,
) -> Dict:
    try:
        logger.info(
            f"Calling agent at {url} with payload keys: {list(json_payload.keys()) if json_payload else 'N/A'}"
        )
        if method == "post":
            if json_payload:
                response = await client.post(url, json=json_payload, timeout=timeout)
            elif files_payload:
                response = await client.post(url, files=files_payload, timeout=timeout)
            else:
                raise ValueError("POST request requires json_payload or files_payload.")
        elif method == "get":
            response = await client.get(url, params=json_payload, timeout=timeout)
        else:
            raise ValueError(f"Unsupported method: {method}")

        response.raise_for_status()
        logger.info(f"Agent at {url} returned status {response.status_code}.")
        return response.json()

    except httpx.ConnectError as e:
        error_msg = f"Connection error calling agent at {url}: {e}"
        logger.error(error_msg)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=error_msg
        )
    except httpx.RequestError as e:
        error_msg = f"Request error calling agent at {url}: {e}"
        logger.error(error_msg)
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail=error_msg
        )
    except httpx.HTTPStatusError as e:
        error_msg = f"HTTP error calling agent at {url}: {e.response.status_code} - {e.response.text[:200]}"
        logger.error(error_msg)
        raise HTTPException(status_code=e.response.status_code, detail=e.response.text)
    except Exception as e:
        error_msg = f"An unexpected error occurred calling agent at {url}: {e}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error_msg
        )


async def stt_node(state: MarketBriefState) -> MarketBriefState:
    async with httpx.AsyncClient() as client:
        if not state.audio_input:
            state.errors.append("STT Node: No audio input provided.")
            logger.error(state.errors[-1])
            state.user_text = "Error: No audio provided for STT."
            return state
        files = {"audio": ("input.wav", state.audio_input, "audio/wav")}
        try:
            response_data = await call_agent(
                client, f"{AGENT_VOICE_URL}/stt", files_payload=files
            )
            if "transcript" in response_data:
                state.user_text = response_data["transcript"]
                logger.info(f"STT successful. Transcript: {state.user_text[:50]}...")
            else:
                error_msg = f"STT agent response missing 'transcript': {response_data}"
                logger.error(error_msg)
                state.errors.append(error_msg)
                state.user_text = "Error: STT failed."
        except HTTPException as e:
            state.errors.append(f"STT Node failed: {e.detail}")
            state.user_text = "Error: STT service unavailable or failed."
    return state


async def nlu_node(state: MarketBriefState) -> MarketBriefState:
    """(NEW) Calls an NLU process (simulated here) to extract intent."""
    if not state.user_text or "Error:" in state.user_text:
        state.warnings.append(
            "NLU Node: Skipping due to missing or error in user_text."
        )
        state.nlu_results = {
            "region": "Global",
            "sector": "Overall Portfolio",
        }
        return state

    logger.info(f"NLU Node: Processing query: '{state.user_text}'")

    query_lower = state.user_text.lower()
    region = "Global"
    sector = "Overall Portfolio"

    if "asia" in query_lower and "tech" in query_lower:
        region = "Asia"
        sector = "Technology"
        logger.info("NLU Simulation: Detected 'Asia' and 'Tech'.")
    elif "us" in query_lower or "usa" in query_lower or "america" in query_lower:
        region = "USA"
        if "tech" in query_lower:
            sector = "Technology"
        elif "health" in query_lower:
            sector = "Healthcare"
        logger.info(f"NLU Simulation: Detected Region '{region}', Sector '{sector}'.")

    state.nlu_results = {"region": region, "sector": sector}
    logger.info(f"NLU Node: Results: {state.nlu_results}")

    target_tickers = []
    portfolio_keys = list(EXAMPLE_PORTFOLIO.keys())

    if region == "Global" and (
        sector == "Overall Portfolio" or sector == "Overall Market"
    ):
        target_tickers = portfolio_keys
    else:
        for ticker, details in EXAMPLE_PORTFOLIO.items():
            matches_region = region == "Global"
            if region == "Asia" and details.get("country") in [
                "Taiwan",
                "China",
                "Korea",
                "Japan",
                "India",
            ]:
                matches_region = True
            elif region == "USA" and details.get("country") == "USA":
                matches_region = True

            matches_sector = sector == "Overall Portfolio" or sector == "Overall Market"
            if sector.lower() == details.get("sector", "").lower():
                matches_sector = True

            if matches_region and matches_sector:
                target_tickers.append(ticker)

    if not target_tickers and portfolio_keys:
        logger.warning(
            f"NLU filtering yielded no specific tickers for {region}/{sector}, defaulting to all portfolio tickers."
        )
        target_tickers = portfolio_keys
        state.nlu_results["region_effective"] = "Global"
        state.nlu_results["sector_effective"] = "Overall Portfolio"

    state.target_tickers_for_data_fetch = list(set(target_tickers))
    logger.info(
        f"NLU Node: Target tickers for data fetch: {state.target_tickers_for_data_fetch}"
    )
    if not state.target_tickers_for_data_fetch:
        state.warnings.append(
            "NLU Node: No target tickers identified for data fetching based on query and portfolio."
        )

    return state


async def api_agent_node(state: MarketBriefState) -> MarketBriefState:
    if not state.target_tickers_for_data_fetch:
        state.warnings.append(
            "API Agent Node: No target tickers to fetch market data for. Skipping."
        )
        state.market_data = {}
        return state

    async with httpx.AsyncClient() as client:
        payload = {
            "tickers": state.target_tickers_for_data_fetch,
            "data_type": "adjClose",
        }
        try:
            response_data = await call_agent(
                client, f"{AGENT_API_URL}/get_market_data", json_payload=payload
            )
            state.market_data = response_data.get("result", {})
            logger.info(
                f"API Agent successful. Fetched data for tickers: {list(state.market_data.keys()) if state.market_data else 'None'}"
            )
            if response_data.get("errors"):
                state.warnings.append(
                    f"API Agent reported errors: {response_data['errors']}"
                )
            if response_data.get("warnings"):
                state.warnings.extend(response_data.get("warnings", []))

        except HTTPException as e:
            state.errors.append(
                f"API Agent Node failed for tickers {state.target_tickers_for_data_fetch}: {e.detail}"
            )
            state.market_data = {}
    return state


async def scraping_agent_node(state: MarketBriefState) -> MarketBriefState:
    if not state.target_tickers_for_data_fetch:
        state.warnings.append(
            "Scraping Agent Node: No target tickers to fetch earnings for. Skipping."
        )
        state.filings = {}
        return state

    async with httpx.AsyncClient() as client:
        filings_data: Dict[str, List[Dict[str, Any]]] = {}
        for ticker in state.target_tickers_for_data_fetch:
            payload = {"ticker": ticker, "filing_type": "earnings_surprise"}
            try:
                response_data = await call_agent(
                    client, f"{AGENT_SCRAPING_URL}/get_filings", json_payload=payload
                )

                if "data" in response_data and isinstance(response_data["data"], list):
                    filings_data[ticker] = response_data["data"]
                    logger.info(
                        f"Scraping Agent got {len(response_data['data'])} records for {ticker}."
                    )
                    if not response_data["data"]:
                        logger.info(
                            f"Scraping Agent for {ticker} returned 0 earnings surprise records."
                        )
                else:
                    filings_data[ticker] = []
                    state.errors.append(
                        f"Scraping agent for {ticker} returned malformed data: {str(response_data)[:100]}"
                    )
            except HTTPException as e:
                state.errors.append(
                    f"Scraping Agent Node failed for {ticker}: {e.detail}"
                )
                filings_data[ticker] = []
        state.filings = filings_data
    return state


async def retriever_agent_node(state: MarketBriefState) -> MarketBriefState:
    async with httpx.AsyncClient() as client:
        docs_to_index = []
        if state.filings:
            for (
                ticker,
                records_list,
            ) in state.filings.items():
                if records_list:
                    doc_content = f"Earnings surprise data for {ticker}:\n" + "\n".join(
                        [
                            f"Date: {r.get('date', 'N/A')}, Symbol: {r.get('symbol', 'N/A')}, "
                            f"Actual: {r.get('actual', 'N/A')}, Estimate: {r.get('estimate', 'N/A')}, "
                            f"Surprise%: {r.get('surprisePercentage', 'N/A')}"
                            for r in records_list
                        ]
                    )
                    docs_to_index.append(doc_content)

        if docs_to_index:
            try:
                pass
            except Exception as e:
                state.errors.append(f"Retriever indexing failed: {e}")
                state.indexed = False
        else:
            state.indexed = False
            logger.info("Retriever: No new documents to index.")

        if state.user_text:
            try:
                pass
            except Exception as e:
                state.errors.append(f"Retriever retrieval failed: {e}")
                state.retrieved_docs = []
        else:
            state.retrieved_docs = []
    return state


async def analysis_agent_node(state: MarketBriefState) -> MarketBriefState:
    if not state.market_data and not state.filings:
        state.warnings.append(
            "Analysis Agent Node: No market data or filings available. Skipping analysis."
        )
        state.analysis = None
        return state

    async with httpx.AsyncClient() as client:
        nlu_res = state.nlu_results if state.nlu_results else {}
        region_label = nlu_res.get("region_effective", nlu_res.get("region", "Global"))
        sector_label = nlu_res.get(
            "sector_effective", nlu_res.get("sector", "Overall Portfolio")
        )

        if region_label == "Global" and (
            sector_label == "Overall Portfolio" or sector_label == "Overall Market"
        ):
            target_label_for_analysis = "Overall Portfolio"
        else:
            target_label_for_analysis = (
                f"{region_label.replace('USA', 'US')} {sector_label} Stocks".strip()
            )

        analysis_target_tickers = state.target_tickers_for_data_fetch

        current_portfolio_weights = {
            ticker: details["weight"] for ticker, details in EXAMPLE_PORTFOLIO.items()
        }

        payload = {
            "portfolio": current_portfolio_weights,
            "market_data": state.market_data if state.market_data else {},
            "earnings_data": (state.filings if state.filings else {}),
            "target_tickers": analysis_target_tickers,
            "target_label": target_label_for_analysis,
        }
        try:
            response_data = await call_agent(
                client, f"{AGENT_ANALYSIS_URL}/analyze", json_payload=payload
            )

            state.analysis = response_data
            logger.info(
                f"Analysis Agent successful for '{response_data.get('target_label')}'."
            )
        except HTTPException as e:
            state.errors.append(f"Analysis Agent Node failed: {e.detail}")
            state.analysis = None
    return state


async def language_agent_node(state: MarketBriefState) -> MarketBriefState:
    async with httpx.AsyncClient() as client:
        if not state.user_text or "Error:" in state.user_text:
            state.errors.append("Language Agent: Skipping due to no valid user text.")
            state.brief = (
                "I could not understand your query or there was an earlier error."
            )
            return state

        analysis_payload_for_llm: Dict[str, Any]
        if state.analysis and isinstance(state.analysis, dict):
            analysis_payload_for_llm = {
                "target_label": state.analysis.get("target_label", "the portfolio"),
                "current_allocation": state.analysis.get("current_allocation", 0.0),
                "yesterday_allocation": state.analysis.get("yesterday_allocation", 0.0),
                "allocation_change_percentage_points": state.analysis.get(
                    "allocation_change_percentage_points", 0.0
                ),
                "earnings_surprises_for_target": state.analysis.get(
                    "earnings_surprises_for_target", []
                ),
            }
        else:
            logger.warning(
                "Language Agent: Analysis data is missing or not a dict. Using defaults."
            )
            state.warnings.append(
                "Language Agent: Analysis data unavailable, brief will be general."
            )
            analysis_payload_for_llm = {
                "target_label": "the portfolio (analysis data missing)",
                "current_allocation": 0.0,
                "yesterday_allocation": 0.0,
                "allocation_change_percentage_points": 0.0,
                "earnings_surprises_for_target": [],
            }

        payload = {
            "user_query": state.user_text,
            "analysis": analysis_payload_for_llm,
            "retrieved_docs": state.retrieved_docs if state.retrieved_docs else [],
        }
        try:
            response_data = await call_agent(
                client, f"{AGENT_LANGUAGE_URL}/generate_brief", json_payload=payload
            )
            state.brief = response_data.get("brief")
            logger.info(f"Language Agent successful. Brief: {state.brief[:70]}...")
        except HTTPException as e:
            state.errors.append(f"Language Agent Node failed: {e.detail}")
            state.brief = "Sorry, I couldn't generate the brief at this time due to an internal error."
    return state


async def tts_node(state: MarketBriefState) -> MarketBriefState:
    brief_text_for_tts = state.brief
    if state.errors and (
        not state.brief
        or "sorry" in state.brief.lower()
        or "error" in state.brief.lower()
    ):
        error_count = len(state.errors)
        brief_text_for_tts = f"I encountered {error_count} error{'s' if error_count > 1 else ''} while processing your request. Please check the detailed report."
        logger.warning(
            f"TTS Node: Generating audio for error summary due to {error_count} errors in state."
        )
    elif not state.brief:
        brief_text_for_tts = "The market brief could not be generated."
        logger.warning("TTS Node: No brief text available from language agent.")
        state.warnings.append("TTS Node: No brief content to synthesize.")

    if not brief_text_for_tts:
        state.audio_output = None
        return state

    async with httpx.AsyncClient() as client:
        payload = {"text": brief_text_for_tts, "lang": "en"}
        try:
            response_data = await call_agent(
                client, f"{AGENT_VOICE_URL}/tts", json_payload=payload
            )
            if "audio" in response_data and isinstance(response_data["audio"], str):
                state.audio_output = bytes.fromhex(response_data["audio"])
                logger.info("TTS successful. Audio bytes received.")
            else:
                state.errors.append(
                    f"TTS Agent response missing or invalid 'audio': {str(response_data)[:100]}"
                )
                state.audio_output = None
        except HTTPException as e:
            state.errors.append(f"TTS Node failed: {e.detail}")
            state.audio_output = None
    return state


def build_market_brief_graph():
    builder = StateGraph(MarketBriefState)
    builder.add_node("stt", stt_node)
    builder.add_node("nlu", nlu_node)
    builder.add_node("api_agent", api_agent_node)
    builder.add_node("scraping_agent", scraping_agent_node)
    builder.add_node("retriever_agent", retriever_agent_node)
    builder.add_node("analysis_agent", analysis_agent_node)
    builder.add_node("language_agent", language_agent_node)
    builder.add_node("tts", tts_node)

    builder.set_entry_point("stt")
    builder.add_edge("stt", "nlu")
    builder.add_edge("nlu", "api_agent")
    builder.add_edge("api_agent", "scraping_agent")
    builder.add_edge("scraping_agent", "retriever_agent")
    builder.add_edge("retriever_agent", "analysis_agent")
    builder.add_edge("analysis_agent", "language_agent")
    builder.add_edge("language_agent", "tts")
    builder.add_edge("tts", END)
    return builder.compile()


graph = build_market_brief_graph()


@app.post("/market_brief")
async def market_brief(audio: UploadFile = File(...)):
    logger.info("Received request to /market_brief")
    if not audio.content_type or not audio.content_type.startswith("audio/"):
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Invalid file type.",
        )

    current_run_state = MarketBriefState()
    try:
        current_run_state.audio_input = await audio.read()
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to read audio: {e}",
        )

    processed_state: MarketBriefState = current_run_state

    try:
        logger.info("Invoking LangGraph workflow...")
        initial_state_dict = current_run_state.model_dump(exclude_none=True)
        invocation_result = await graph.ainvoke(initial_state_dict)

        if isinstance(invocation_result, dict):
            processed_state = MarketBriefState(**invocation_result)
            logger.info("LangGraph execution finished. State updated.")
        else:
            logger.error(
                f"LangGraph ainvoke returned unexpected type: {type(invocation_result)}. Using partially updated state."
            )
            processed_state.errors.append(
                f"Internal graph error: result type {type(invocation_result)}"
            )

    except HTTPException as e:
        logger.error(
            f"Graph execution stopped due to HTTPException from an agent: {e.detail}"
        )
        processed_state.errors.append(f"Agent call failed: {e.detail}")
    except Exception as e:
        error_msg = f"An unexpected error occurred during graph execution: {e}"
        logger.error(error_msg, exc_info=True)
        processed_state.errors.append(error_msg)

    response_payload = {
        "transcript": processed_state.user_text,
        "brief": processed_state.brief,
        "audio": (
            processed_state.audio_output.hex() if processed_state.audio_output else None
        ),
        "errors": processed_state.errors,
        "warnings": processed_state.warnings,
        "status": "success" if not processed_state.errors else "failed",
        "message": "Market brief process completed."
        + (" With errors." if processed_state.errors else " Successfully."),
        "nlu_detected": processed_state.nlu_results,
        "analysis_details": processed_state.analysis,
    }
    logger.info(
        f"Request finished. Status: {response_payload['status']}. Errors: {len(response_payload['errors'])}. Warnings: {len(response_payload['warnings'])}."
    )
    return response_payload
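A minimal client sketch against the /market_brief endpoint above. It assumes the orchestrator is reachable on localhost:8000 (the port mapping in docker-compose.yaml) and that a query.wav file exists; both are assumptions, not part of the commit.

```python
# Post a WAV query to the orchestrator and unpack the response fields it returns.
import httpx

with open("query.wav", "rb") as f:
    files = {"audio": ("query.wav", f.read(), "audio/wav")}

resp = httpx.post("http://localhost:8000/market_brief", files=files, timeout=180.0)
resp.raise_for_status()
payload = resp.json()

print(payload["status"])      # "success" or "failed"
print(payload["transcript"])  # STT result
print(payload["brief"])       # generated market brief text
if payload.get("audio"):
    # The TTS audio comes back hex-encoded, mirroring how the Streamlit app decodes it.
    with open("brief.mp3", "wb") as out:
        out.write(bytes.fromhex(payload["audio"]))
```
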
requirements.txt
ADDED
@@ -0,0 +1,18 @@
fastapi
uvicorn
httpx  # imported directly by the orchestrator and the Streamlit client
pydantic
requests
python-dotenv
numpy
faiss-cpu
sentence-transformers
langchain>=0.1.0
langchain-core
langchain-community
langchain-openai
gtts
faster-whisper
python-multipart
langgraph
streamlit
streamlit-mic-recorder
streamlit/app.py
ADDED
@@ -0,0 +1,343 @@
import streamlit as st
import httpx
import os
import io
from dotenv import load_dotenv
import logging
import asyncio
from streamlit_mic_recorder import mic_recorder


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
load_dotenv()

ORCHESTRATOR_URL = os.getenv("ORCHESTRATOR_URL")


if "processing_state" not in st.session_state:
    st.session_state.processing_state = "initial"
if "orchestrator_response" not in st.session_state:
    st.session_state.orchestrator_response = None
if "audio_bytes_input" not in st.session_state:
    st.session_state.audio_bytes_input = None
if "audio_filename" not in st.session_state:
    st.session_state.audio_filename = None
if "audio_filetype" not in st.session_state:
    st.session_state.audio_filetype = None
if "last_audio_source" not in st.session_state:
    st.session_state.last_audio_source = None
if "current_recording_id" not in st.session_state:
    st.session_state.current_recording_id = None


async def call_orchestrator(audio_bytes: bytes, filename: str, content_type: str):
    url = f"{ORCHESTRATOR_URL}/market_brief"
    files = {"audio": (filename, audio_bytes, content_type)}
    logger.info(
        f"Calling orchestrator at {url} with audio file: {filename} ({content_type})"
    )
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, files=files, timeout=180.0)
            response.raise_for_status()
            logger.info(f"Orchestrator returned status {response.status_code}.")
            return response.json()
    except httpx.RequestError as e:
        error_msg = f"HTTP Request failed: {e}"
        logger.error(error_msg)
        return {
            "status": "error",
            "message": "Error communicating with orchestrator.",
            "errors": [error_msg],
            "transcript": None,
            "brief": None,
            "audio": None,
        }
    except Exception as e:
        error_msg = f"An unexpected error occurred: {e}"
        logger.error(error_msg)
        return {
            "status": "error",
            "message": "An unexpected error occurred.",
            "errors": [error_msg],
            "transcript": None,
            "brief": None,
            "audio": None,
        }


st.set_page_config(layout="wide")
st.title("📈 AI Financial Assistant - Morning Market Brief")
st.markdown(
    "Ask your query verbally (e.g., 'What's our risk exposure in Asia tech stocks today, and highlight any earnings surprises?') "
    "or upload an audio file."
)

input_method = st.radio(
    "Choose input method:",
    ("Record Audio", "Upload File"),
    horizontal=True,
    index=0,
    key="input_method_radio",
)

audio_data_ready = False


if st.session_state.audio_bytes_input is not None:
    audio_data_ready = True


if input_method == "Record Audio":
    st.subheader("Record Your Query")

    if st.session_state.last_audio_source == "uploader":
        st.session_state.audio_bytes_input = None
        st.session_state.audio_filename = None
        st.session_state.audio_filetype = None
        st.session_state.last_audio_source = "recorder"
        audio_data_ready = False

    audio_info = mic_recorder(
        start_prompt="⏺️ Start Recording",
        stop_prompt="⏹️ Stop Recording",
        just_once=False,
        use_container_width=True,
        format="wav",
        key="mic_recorder_widget",
    )

    if audio_info and audio_info.get("bytes"):
        if st.session_state.current_recording_id != audio_info.get("id"):
            st.session_state.current_recording_id = audio_info.get("id")
            st.success("Recording complete! Click 'Generate Market Brief' below.")
            st.session_state.audio_bytes_input = audio_info["bytes"]
            st.session_state.audio_filename = f"live_recording_{audio_info['id']}.wav"
            st.session_state.audio_filetype = "audio/wav"
            st.session_state.last_audio_source = "recorder"
            audio_data_ready = True
            st.session_state.processing_state = "initial"
            st.session_state.orchestrator_response = None
            st.audio(audio_info["bytes"])

        elif st.session_state.audio_bytes_input:
            audio_data_ready = True
            st.audio(st.session_state.audio_bytes_input)

    elif (
        st.session_state.last_audio_source == "recorder"
        and st.session_state.audio_bytes_input
    ):
        st.markdown("Using last recording:")
        st.audio(st.session_state.audio_bytes_input)
        audio_data_ready = True


elif input_method == "Upload File":
    st.subheader("Upload Audio File")

    if st.session_state.last_audio_source == "recorder":
        st.session_state.audio_bytes_input = None
        st.session_state.audio_filename = None
        st.session_state.audio_filetype = None
        st.session_state.last_audio_source = "uploader"
        st.session_state.current_recording_id = None
        audio_data_ready = False

    if "uploaded_file_state" not in st.session_state:
        st.session_state.uploaded_file_state = None

    uploaded_file = st.file_uploader(
        "Select Audio File",
        type=["wav", "mp3", "m4a", "ogg", "flac"],
        key="file_uploader_key",
    )

    if uploaded_file is not None:
        if st.session_state.uploaded_file_state != uploaded_file:
            st.session_state.uploaded_file_state = uploaded_file
            st.session_state.audio_bytes_input = uploaded_file.getvalue()
            st.session_state.audio_filename = uploaded_file.name
            st.session_state.audio_filetype = uploaded_file.type
            st.session_state.last_audio_source = "uploader"
            audio_data_ready = True
            st.session_state.processing_state = "initial"
            st.session_state.orchestrator_response = None
            st.success(f"File '{uploaded_file.name}' ready.")
            st.audio(
                st.session_state.audio_bytes_input,
                format=st.session_state.audio_filetype,
            )
        elif st.session_state.audio_bytes_input:
            audio_data_ready = True
            st.audio(
                st.session_state.audio_bytes_input,
                format=st.session_state.audio_filetype,
            )

    elif (
        st.session_state.last_audio_source == "uploader"
        and st.session_state.audio_bytes_input
    ):
        st.markdown("Using last uploaded file:")
        st.audio(
            st.session_state.audio_bytes_input, format=st.session_state.audio_filetype
        )
        audio_data_ready = True


st.divider()
button_disabled = (
    not audio_data_ready or st.session_state.processing_state == "processing"
)

if st.button(
    "Generate Market Brief",
    disabled=button_disabled,
    type="primary",
    use_container_width=True,
    key="generate_button",
):
    if st.session_state.audio_bytes_input:
        st.session_state.processing_state = "processing"
        st.session_state.orchestrator_response = None
        logger.info(
            f"Generate Market Brief button clicked. Source: {st.session_state.last_audio_source}, Filename: {st.session_state.audio_filename}"
        )
        st.rerun()
    else:
        st.warning("Please record or upload an audio query first.")


if st.session_state.processing_state == "processing":
    if (
        st.session_state.audio_bytes_input
        and st.session_state.audio_filename
        and st.session_state.audio_filetype
    ):
        with st.spinner("Processing your request... This may take a moment. 🤖"):
            logger.info(
                f"Calling orchestrator with filename: {st.session_state.audio_filename}, type: {st.session_state.audio_filetype}, bytes: {len(st.session_state.audio_bytes_input)}"
            )
            try:
                response = asyncio.run(
                    call_orchestrator(
                        st.session_state.audio_bytes_input,
                        st.session_state.audio_filename,
                        st.session_state.audio_filetype,
                    )
                )
                st.session_state.orchestrator_response = response

                is_successful_response = True
                if not response:
                    is_successful_response = False
                elif (
                    response.get("status") == "error"
                    or response.get("status") == "failed"
                ):
                    is_successful_response = False
                elif response.get("errors") and len(response.get("errors")) > 0:
                    is_successful_response = False

                st.session_state.processing_state = (
                    "completed" if is_successful_response else "error"
                )

            except Exception as e:
                logger.error(
                    f"Error during orchestrator call in Streamlit: {e}", exc_info=True
                )
                st.session_state.orchestrator_response = {
                    "status": "error",
                    "message": f"Streamlit failed to call orchestrator: {str(e)}",
                    "errors": [str(e)],
                    "transcript": None,
                    "brief": None,
                    "audio": None,
                }
                st.session_state.processing_state = "error"
            st.rerun()
    else:
        st.error("Audio data is missing for processing. Please record or upload again.")
        st.session_state.processing_state = "initial"


if st.session_state.processing_state in ["completed", "error"]:
    response = st.session_state.orchestrator_response
    st.subheader("📝 Results")

    if response is None:
        st.error("No response received from the orchestrator.")

    elif (
        response.get("status") == "failed"
        or response.get("status") == "error"
        or (response.get("errors") and len(response.get("errors")) > 0)
    ):
        st.error(
            f"Workflow {response.get('status', 'failed')}: {response.get('message', 'Check errors below.')}"
        )
        if response.get("errors"):
            st.warning("Details of Errors:")
            for i, err in enumerate(response["errors"]):
                st.markdown(f"`Error {i+1}`: {err}")
        if response.get("warnings"):
            st.warning("Details of Warnings:")
            for i, warn in enumerate(response["warnings"]):
                st.markdown(f"`Warning {i+1}`: {warn}")

        if response.get("transcript"):
            st.markdown("---")
            st.markdown("Transcript (despite errors):")
            st.caption(response.get("transcript"))
        if response.get("brief"):
            st.markdown("---")
            st.markdown("Generated Brief (despite errors):")
            st.caption(response.get("brief"))
    else:
        st.success(response.get("message", "Market brief generated successfully!"))
        if response.get("transcript"):
            st.markdown("---")
            st.markdown("Your Query (Transcript):")
            st.caption(response.get("transcript"))
        else:
            st.info("Transcript not available.")

        if response.get("brief"):
            st.markdown("---")
            st.markdown("Generated Brief:")
            st.write(response.get("brief"))
        else:
            st.info("Brief text not available.")

        audio_hex = response.get("audio")
        if audio_hex:
            st.markdown("---")
            st.markdown("Audio Brief:")
            try:
                if not isinstance(audio_hex, str) or not all(
                    c in "0123456789abcdefABCDEF" for c in audio_hex
                ):
                    raise ValueError("Invalid hex string for audio.")
                audio_bytes_output = bytes.fromhex(audio_hex)
                st.audio(audio_bytes_output, format="audio/mpeg")
            except ValueError as ve:
                st.error(f"⚠️ Failed to decode audio data: {ve}")
            except Exception as e:
                st.error(f"⚠️ Failed to play audio: {e}")
        else:
            st.info("Audio brief not available.")

        if response.get("warnings"):
            st.markdown("---")
            st.warning("Process Warnings:")
            for i, warn in enumerate(response["warnings"]):
                st.markdown(f"`Warning {i+1}`: {warn}")