svsaurav95 commited on
Commit
1be9e52
·
verified ·
1 Parent(s): c725c88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -176
app.py CHANGED
@@ -1,178 +1,93 @@
1
  import streamlit as st
2
- import pymupdf
3
- import re
4
- import traceback
5
- import faiss
6
- import numpy as np
7
- import requests
8
- from rank_bm25 import BM25Okapi
9
- from sentence_transformers import SentenceTransformer
10
- from langchain.text_splitter import RecursiveCharacterTextSplitter
11
- from langchain_groq import ChatGroq
12
  import torch
13
- import os
14
-
15
- os.environ["STREAMLIT_WATCHDOG_TYPE"] = "none"
16
-
17
- st.set_page_config(page_title="Financial Insights Chatbot", page_icon="📊", layout="wide")
18
-
19
- device = "cuda" if torch.cuda.is_available() else "cpu"
20
-
21
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
22
- ALPHA_VANTAGE_API_KEY = os.getenv("ALPHA_VANTAGE_API_KEY")
23
-
24
- try:
25
- llm = ChatGroq(temperature=0, model="llama3-70b-8192", api_key=GROQ_API_KEY)
26
- st.success(" LLM initialized successfully. Using llama3-70b-8192")
27
- except Exception as e:
28
- st.error("❌ Failed to initialize Groq LLM.")
29
- traceback.print_exc()
30
-
31
- embedding_model = SentenceTransformer("baconnier/Finance2_embedding_small_en-V1.5", device=device)
32
-
33
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
34
-
35
- def fetch_financial_data(company_ticker):
36
- if not company_ticker:
37
- return "No ticker symbol provided. Please enter a valid company ticker."
38
-
39
- try:
40
- overview_url = f"https://www.alphavantage.co/query?function=OVERVIEW&symbol={company_ticker}&apikey={ALPHA_VANTAGE_API_KEY}"
41
- overview_response = requests.get(overview_url)
42
-
43
- if overview_response.status_code == 200:
44
- overview_data = overview_response.json()
45
- market_cap = overview_data.get("MarketCapitalization", "N/A")
46
- else:
47
- return "Error fetching company overview."
48
-
49
- income_url = f"https://www.alphavantage.co/query?function=INCOME_STATEMENT&symbol={company_ticker}&apikey={ALPHA_VANTAGE_API_KEY}"
50
- income_response = requests.get(income_url)
51
-
52
- if income_response.status_code == 200:
53
- income_data = income_response.json()
54
- annual_reports = income_data.get("annualReports", [])
55
- revenue = annual_reports[0].get("totalRevenue", "N/A") if annual_reports else "N/A"
56
- else:
57
- return "Error fetching income statement."
58
-
59
- return f"Market Cap: ${market_cap}\nTotal Revenue: ${revenue}"
60
-
61
- except Exception as e:
62
- traceback.print_exc()
63
- return "Error fetching financial data."
64
-
65
- def extract_and_embed_text(pdf_file):
66
- """Processes PDFs and generates embeddings with GPU acceleration using pymupdf."""
67
- try:
68
- docs, tokenized_texts = [], []
69
-
70
- with pymupdf.open(stream=pdf_file.read(), filetype="pdf") as doc:
71
- full_text = "\n".join(page.get_text("text") for page in doc)
72
- chunks = text_splitter.split_text(full_text)
73
- for chunk in chunks:
74
- docs.append(chunk)
75
- tokenized_texts.append(chunk.split())
76
-
77
- embeddings = embedding_model.encode(docs, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
78
-
79
- embedding_dim = embeddings.shape[1]
80
- index = faiss.IndexHNSWFlat(embedding_dim, 32)
81
- index.add(embeddings)
82
-
83
- bm25 = BM25Okapi(tokenized_texts)
84
-
85
- return docs, embeddings, index, bm25
86
- except Exception as e:
87
- traceback.print_exc()
88
- return [], [], None, None
89
-
90
- def retrieve_relevant_docs(user_query, docs, index, bm25):
91
- """Hybrid search using FAISS cosine similarity & BM25 keyword retrieval."""
92
- query_embedding = embedding_model.encode(user_query, convert_to_numpy=True, normalize_embeddings=True)
93
- _, faiss_indices = index.search(np.array([query_embedding]), 8)
94
- bm25_scores = bm25.get_scores(user_query.split())
95
- bm25_indices = np.argsort(bm25_scores)[::-1][:8]
96
- combined_indices = list(set(faiss_indices[0]) | set(bm25_indices))
97
-
98
- return [docs[i] for i in combined_indices[:3]]
99
-
100
- def generate_response(user_query, pdf_ticker, ai_ticker, mode, uploaded_file):
101
- try:
102
- if mode == "📄 PDF Upload Mode":
103
- docs, embeddings, index, bm25 = extract_and_embed_text(uploaded_file)
104
- if not docs:
105
- return "❌ Error extracting text from PDF."
106
-
107
- retrieved_docs = retrieve_relevant_docs(user_query, docs, index, bm25)
108
- context = "\n\n".join(retrieved_docs)
109
-
110
- # Avoid using 'None' in prompt
111
- prompt = f"Based on the uploaded financial report, answer the following query:\n{user_query}\n\nRelevant context:\n{context}"
112
-
113
-
114
- elif mode == "🌍 Live Data Mode":
115
- financial_info = fetch_financial_data(ai_ticker)
116
- prompt = f"Analyze the financial status of {ai_ticker} based on:\n{financial_info}\n\nUser Query: {user_query}"
117
- else:
118
- return "Invalid mode selected."
119
-
120
- response = llm.invoke(prompt)
121
- return response.content
122
- except Exception as e:
123
- traceback.print_exc()
124
- return "Error generating response."
125
-
126
- st.markdown(
127
- "<h1 style='text-align: center; color: #4CAF50;'> FinQuery RAG Chatbot</h1>",
128
- unsafe_allow_html=True
129
- )
130
- st.markdown(
131
- "<h5 style='text-align: center; color: #666;'>Analyze financial reports or fetch live financial data effortlessly!</h5>",
132
- unsafe_allow_html=True
133
- )
134
-
135
- col1, col2 = st.columns(2)
136
-
137
- with col1:
138
- st.markdown("### 🏢 **Choose Your Analysis Mode**")
139
- mode = st.radio("", ["📄 PDF Upload Mode", "🌍 Live Data Mode"], horizontal=True)
140
-
141
- with col2:
142
- st.markdown("### **Enter Your Query**")
143
- user_query = st.text_input("💬 What financial insights are you looking for?")
144
-
145
- st.markdown("---")
146
- uploaded_file, company_ticker = None, None
147
-
148
- if mode == "📄 PDF Upload Mode":
149
- st.markdown("### 📂 Upload Your Financial Report")
150
- uploaded_file = st.file_uploader("🔼 Upload PDF Report", type=["pdf"])
151
- company_ticker = None
152
-
153
- else:
154
- st.markdown("### 🌍 Live Market Data")
155
- company_ticker = st.text_input("🏢 Enter Company Ticker Symbol", placeholder="e.g., AAPL, MSFT")
156
- uploaded_file = None
157
-
158
- # 🎯 Submit Button
159
- if st.button("Analyze Now"):
160
- if mode == "📄 PDF Upload Mode" and not uploaded_file:
161
- st.error("❌ Please upload a PDF file.")
162
- elif mode == "🌍 Live Data Mode" and not company_ticker:
163
- st.error("❌ Please enter a valid company ticker symbol.")
164
- else:
165
- with st.spinner(" Your Query is Processing, this can take up to 5 - 7 minutes ⏳"):
166
- if mode == "📄 PDF Upload Mode":
167
- response = generate_response(user_query, company_ticker, None, mode, uploaded_file)
168
- else:
169
- response = generate_response(user_query, None, company_ticker, mode, uploaded_file)
170
-
171
- st.markdown("---")
172
- st.markdown("<h3 style='color: #4CAF50;'>💡 AI Response</h3>", unsafe_allow_html=True)
173
- st.write(response)
174
-
175
-
176
- # 📌 Footer
177
- st.markdown("---")
178
-
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
+ import torch.nn as nn
4
+ import timm
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+ import io
9
+
10
+ # Hide Streamlit warnings and UI elements
11
+ st.set_page_config(layout="wide")
12
+ st.markdown("""
13
+ <style>
14
+ footer {visibility: hidden;}
15
+ </style>
16
+ """, unsafe_allow_html=True)
17
+
18
+ # === Model Definition ===
19
+ class MobileViTSegmentation(nn.Module):
20
+ def __init__(self, encoder_name='mobilevit_s', pretrained=False):
21
+ super().__init__()
22
+ self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
23
+ self.encoder_channels = self.backbone.feature_info.channels()
24
+
25
+ self.decoder = nn.Sequential(
26
+ nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
27
+ nn.Upsample(scale_factor=2, mode='bilinear'),
28
+ nn.Conv2d(128, 64, kernel_size=3, padding=1),
29
+ nn.Upsample(scale_factor=2, mode='bilinear'),
30
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
31
+ nn.Upsample(scale_factor=2, mode='bilinear'),
32
+ nn.Conv2d(32, 1, kernel_size=1),
33
+ nn.Sigmoid()
34
+ )
35
+
36
+ def forward(self, x):
37
+ feats = self.backbone(x)
38
+ out = self.decoder(feats[-1])
39
+ out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
40
+ return out
41
+
42
+ # === Load Model ===
43
+ @st.cache_resource
44
+ def load_model():
45
+ model = MobileViTSegmentation()
46
+ state_dict = torch.load("mobilevit_teeth_segmentation.pth", map_location="cpu")
47
+ model.load_state_dict(state_dict)
48
+ model.eval()
49
+ return model
50
+
51
+ model = load_model()
52
+
53
+ # === Preprocessing ===
54
+ def preprocess_image(image: Image.Image):
55
+ image = image.convert("RGB").resize((256, 256))
56
+ arr = np.array(image).astype(np.float32) / 255.0
57
+ arr = np.transpose(arr, (2, 0, 1)) # HWC → CHW
58
+ tensor = torch.tensor(arr).unsqueeze(0) # Add batch dim
59
+ return tensor
60
+
61
+ # === Postprocessing: Overlay Mask ===
62
+ def overlay_mask(image_pil, mask_tensor, threshold=0.7):
63
+ image = np.array(image_pil.resize((256, 256)))
64
+ mask = mask_tensor.squeeze().detach().numpy()
65
+ mask_bin = (mask > threshold).astype(np.uint8) * 255
66
+
67
+ mask_color = np.zeros_like(image)
68
+ mask_color[..., 2] = mask_bin # Blue mask
69
+
70
+ overlayed = cv2.addWeighted(image, 1.0, mask_color, 0.5, 0)
71
+ return overlayed
72
+
73
+ # === UI ===
74
+ st.title("🦷 Tooth Segmentation with MobileViT")
75
+ st.write("Upload an image to segment the **visible teeth area** using a lightweight MobileViT segmentation model.")
76
+
77
+ uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
78
+
79
+ if uploaded_file:
80
+ image = Image.open(uploaded_file)
81
+ tensor = preprocess_image(image)
82
+
83
+ with st.spinner("Segmenting..."):
84
+ with torch.no_grad():
85
+ pred = model(tensor)[0]
86
+
87
+ overlayed_img = overlay_mask(image, pred)
88
+
89
+ col1, col2 = st.columns(2)
90
+ with col1:
91
+ st.image(image, caption="Original Image", use_container_width=True)
92
+ with col2:
93
+ st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)