|
import os |
|
import numpy as np |
|
import pandas as pd |
|
import faiss |
|
from sentence_transformers import SentenceTransformer |
|
from groq import Groq |
|
|
|
|
|
# Name of the sentence-transformers checkpoint used to embed text in this file.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

# Instantiated once at import time; generate_embeddings() reads this global.
model = SentenceTransformer(EMBEDDING_MODEL_NAME)
|
|
|
|
|
# The API key is read from the environment so that no secret lives in source
# control. NOTE(review): the key that was previously committed on this line
# should be treated as leaked and revoked in the Groq console.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; "
        "export it before running this script."
    )

# Shared Groq client used by generate_report_with_groq().
client = Groq(api_key=GROQ_API_KEY)
|
|
|
|
|
def generate_embeddings(text):
    """Encode ``text`` with the module-level sentence-transformer ``model``.

    Args:
        text: Input handed straight to ``model.encode``.

    Returns:
        Whatever ``model.encode`` returns for ``text``; callers in this file
        treat it as a NumPy array (they call ``.astype`` on it).
    """
    embedding = model.encode(text)
    return embedding
|
|
|
|
|
def build_faiss_index(data):
    """Embed every row of ``data`` and store the vectors in a flat L2 index.

    Each row is rendered with ``row.to_string()`` and all rows are encoded in
    a single batched ``generate_embeddings`` call instead of one call per row.
    The index dimension is taken from the embeddings themselves rather than
    being hard-coded, so swapping the embedding model does not break indexing.

    Args:
        data: pandas DataFrame whose rows should be indexed.

    Returns:
        Tuple ``(index, embeddings)``: the populated ``faiss.IndexFlatL2`` and
        the float32 array of shape ``(len(data), dim)`` that was added to it.

    Raises:
        ValueError: If ``data`` has no rows (there is nothing to index and the
            embedding dimension cannot be determined).
    """
    texts = [row.to_string() for _, row in data.iterrows()]
    if not texts:
        raise ValueError("Cannot build a FAISS index from an empty DataFrame.")

    # Passing a list lets the encoder batch the work; the result is expected
    # to be a 2-D (n_rows, dim) array.
    embeddings = np.asarray(generate_embeddings(texts), dtype="float32")

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
|
|
|
|
|
def query_index(query, data, index, k=5):
    """Return the rows of ``data`` whose embeddings are nearest to ``query``.

    Args:
        query: Text to embed and search for.
        data: DataFrame whose row order matches the vectors stored in
            ``index`` (row ``i`` corresponds to vector ``i``).
        index: FAISS index to search.
        k: Maximum number of neighbours to return. Defaults to 5, the value
            that was previously hard-coded.

    Returns:
        DataFrame with at most ``k`` rows, ordered as returned by the index
        search. Empty (same columns) when the index holds no vectors or
        ``k`` is not positive.
    """
    # Never ask for more neighbours than exist: FAISS pads missing results
    # with label -1, and ``iloc[-1]`` would silently return the last row.
    k = min(k, index.ntotal)
    if k <= 0:
        return data.iloc[[]]

    query_embedding = np.asarray(
        generate_embeddings(query), dtype="float32"
    ).reshape(1, -1)
    _, indices = index.search(query_embedding, k)

    hits = indices[0]
    # Defensive filter in case the index still reports unfilled slots.
    return data.iloc[hits[hits >= 0]]
|
|
|
|
|
def generate_report_with_groq(query, results, model_name="llama3-8b-8192"):
    """Send the query and the retrieved rows to the Groq chat API.

    Args:
        query: The user's natural-language question; interpolated into the
            prompt.
        results: DataFrame of retrieved rows; rendered into the prompt with
            ``to_string(index=False)``.
        model_name: Groq chat model ID to use. Defaults to the ID that was
            previously hard-coded, so existing callers are unaffected.
            NOTE(review): hosted model IDs can be retired by the provider;
            confirm the default is still served.

    Returns:
        The ``content`` of the first choice's message in the API response.
    """
    input_text = (
        f"Based on the query '{query}', the following insights are generated:"
        f"\n\n{results.to_string(index=False)}"
    )

    response = client.chat.completions.create(
        messages=[{"role": "user", "content": input_text}],
        model=model_name,
        stream=False,
    )
    return response.choices[0].message.content
|
|
|
|
|
def main():
    """Run the demo: load the CSV, index it, retrieve rows, print a report."""
    # Load the dataset and blank out missing values before rows are rendered
    # to text for embedding.
    csv_path = "energy_usage_data.csv"
    data = pd.read_csv(csv_path)
    data.fillna("", inplace=True)

    # Embed every row and build the search index.
    print("Building FAISS index...")
    index, _ = build_faiss_index(data)

    # Fixed example query for this demo run.
    query = "Show households with high energy usage in the North region"
    print(f"User Query: {query}")

    # Nearest-neighbour lookup of the rows most similar to the query.
    print("Retrieving relevant data...")
    results = query_index(query, data, index)

    # Ask the LLM to write a report from the retrieved rows.
    print("Generating report using Groq API...")
    report = generate_report_with_groq(query, results)

    print("Generated Report:\n")
    print(report)


if __name__ == "__main__":
    main()
|
|