import os
import numpy as np
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from groq import Groq
# Load pre-trained Sentence Transformer model
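# all-MiniLM-L6-v2 produces 384-dimensional embeddings, matching the FAISS index dimension below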
model = SentenceTransformer('all-MiniLM-L6-v2')
# Initialize Groq API client
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")  # Read the key from the environment; never hardcode secrets in source
client = Groq(api_key=GROQ_API_KEY)
# Generate embeddings using Sentence Transformers
def generate_embeddings(text):
    return model.encode(text)
# Build FAISS index for retrieval
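# Each row is serialized with DataFrame.to_string(), embedded, and stored in a flat (exact) L2 index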
def build_faiss_index(data):
    index = faiss.IndexFlatL2(384)  # 384-dimensional embeddings for MiniLM
    embeddings = [generate_embeddings(row.to_string()) for _, row in data.iterrows()]
    embeddings = np.array(embeddings).astype("float32")
    index.add(embeddings)
    return index, embeddings
# Query FAISS index
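# Embed the query and return the 5 nearest rows of the original DataFrame by L2 distance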
def query_index(query, data, index):
    query_embedding = generate_embeddings(query).astype("float32")
    distances, indices = index.search(np.array([query_embedding]), k=5)
    results = data.iloc[indices[0]]
    return results
# Generate a detailed report using Groq's generative model
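# Send the query plus the retrieved rows to Groq's chat completions API and return the model's response text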
def generate_report_with_groq(query, results):
    input_text = f"Based on the query '{query}', the following insights are generated:\n\n{results.to_string(index=False)}"
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": input_text}],
        model="llama3-8b-8192",
        stream=False
    )
    return response.choices[0].message.content
# Main function to execute the workflow
if __name__ == "__main__":
    # Load dataset
    csv_path = "energy_usage_data.csv"  # Ensure this CSV is uploaded to your working directory
    data = pd.read_csv(csv_path)

    # Preprocess data (if needed)
    data.fillna("", inplace=True)

    # Build FAISS index
    print("Building FAISS index...")
    index, embeddings = build_faiss_index(data)

    # User query
    query = "Show households with high energy usage in the North region"
    print(f"User Query: {query}")

    # Query FAISS index
    print("Retrieving relevant data...")
    results = query_index(query, data, index)

    # Generate report
    print("Generating report using Groq API...")
    report = generate_report_with_groq(query, results)

    # Output the report
    print("Generated Report:\n")
    print(report)