DrishtiSharma committed
Commit 5803fa4 · verified · 1 Parent(s): 592efed

Delete interim_radio.py

Files changed (1)
  1. interim_radio.py +0 -161
interim_radio.py DELETED
@@ -1,161 +0,0 @@
-import streamlit as st
-import pandas as pd
-import sqlite3
-import os
-import json
-from pathlib import Path
-from datetime import datetime, timezone
-from crewai import Agent, Crew, Process, Task
-from crewai_tools import tool
-from langchain_groq import ChatGroq
-from langchain.schema.output import LLMResult
-from langchain_core.callbacks.base import BaseCallbackHandler
-from langchain_community.tools.sql_database.tool import (
-    InfoSQLDatabaseTool,
-    ListSQLDatabaseTool,
-    QuerySQLCheckerTool,
-    QuerySQLDataBaseTool,
-)
-from langchain_community.utilities.sql_database import SQLDatabase
-from datasets import load_dataset
-import tempfile
-
-# API Key
-os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
-
-# Initialize LLM
-class LLMCallbackHandler(BaseCallbackHandler):
-    def __init__(self, log_path: Path):
-        self.log_path = log_path
-
-    def on_llm_start(self, serialized, prompts, **kwargs):
-        with self.log_path.open("a", encoding="utf-8") as file:
-            file.write(json.dumps({"event": "llm_start", "text": prompts[0], "timestamp": datetime.now().isoformat()}) + "\n")
-
-    def on_llm_end(self, response: LLMResult, **kwargs):
-        generation = response.generations[-1][-1].message.content
-        with self.log_path.open("a", encoding="utf-8") as file:
-            file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
-
-llm = ChatGroq(
-    temperature=0,
-    model_name="groq/llama-3.3-70b-versatile",
-    max_tokens=1024,
-    callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
-)
-
-st.title("SQL-RAG Using CrewAI 🚀")
-st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
-
-# Initialize session state for data persistence
-if "df" not in st.session_state:
-    st.session_state.df = None
-
-# Dataset Input
-input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
-if input_option == "Use Hugging Face Dataset":
-    dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
-    if st.button("Load Dataset"):
-        try:
-            with st.spinner("Loading dataset..."):
-                dataset = load_dataset(dataset_name, split="train")
-                st.session_state.df = pd.DataFrame(dataset)
-                st.success(f"Dataset '{dataset_name}' loaded successfully!")
-                st.dataframe(st.session_state.df.head())
-        except Exception as e:
-            st.error(f"Error: {e}")
-elif input_option == "Upload CSV File":
-    uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
-    if uploaded_file:
-        st.session_state.df = pd.read_csv(uploaded_file)
-        st.success("File uploaded successfully!")
-        st.dataframe(st.session_state.df.head())
-
-# SQL-RAG Analysis
-if st.session_state.df is not None:
-    temp_dir = tempfile.TemporaryDirectory()
-    db_path = os.path.join(temp_dir.name, "data.db")
-    connection = sqlite3.connect(db_path)
-    st.session_state.df.to_sql("salaries", connection, if_exists="replace", index=False)
-    db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
-
-    @tool("list_tables")
-    def list_tables() -> str:
-        """List all tables in the database."""
-        return ListSQLDatabaseTool(db=db).invoke("")
-
-    @tool("tables_schema")
-    def tables_schema(tables: str) -> str:
-        """Get schema and sample rows for given tables."""
-        return InfoSQLDatabaseTool(db=db).invoke(tables)
-
-    @tool("execute_sql")
-    def execute_sql(sql_query: str) -> str:
-        """Execute a SQL query against the database."""
-        return QuerySQLDataBaseTool(db=db).invoke(sql_query)
-
-    @tool("check_sql")
-    def check_sql(sql_query: str) -> str:
-        """Check the validity of a SQL query."""
-        return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
-
-    sql_dev = Agent(
-        role="Senior Database Developer",
-        goal="Extract data using optimized SQL queries.",
-        backstory="An expert in writing optimized SQL queries for complex databases.",
-        llm=llm,
-        tools=[list_tables, tables_schema, execute_sql, check_sql],
-    )
-
-    data_analyst = Agent(
-        role="Senior Data Analyst",
-        goal="Analyze the data and produce insights.",
-        backstory="A seasoned analyst who identifies trends and patterns in datasets.",
-        llm=llm,
-    )
-
-    report_writer = Agent(
-        role="Technical Report Writer",
-        goal="Summarize the insights into a clear report.",
-        backstory="An expert in summarizing data insights into readable reports.",
-        llm=llm,
-    )
-
-    extract_data = Task(
-        description="Extract data based on the query: {query}.",
-        expected_output="Database results matching the query.",
-        agent=sql_dev,
-    )
-
-    analyze_data = Task(
-        description="Analyze the extracted data for query: {query}.",
-        expected_output="Analysis text summarizing findings.",
-        agent=data_analyst,
-        context=[extract_data],
-    )
-
-    write_report = Task(
-        description="Summarize the analysis into an executive report.",
-        expected_output="Markdown report of insights.",
-        agent=report_writer,
-        context=[analyze_data],
-    )
-
-    crew = Crew(
-        agents=[sql_dev, data_analyst, report_writer],
-        tasks=[extract_data, analyze_data, write_report],
-        process=Process.sequential,
-        verbose=True,
-    )
-
-    query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary for senior employees?'")
-    if st.button("Submit Query"):
-        with st.spinner("Processing query..."):
-            inputs = {"query": query}
-            result = crew.kickoff(inputs=inputs)
-            st.markdown("### Analysis Report:")
-            st.markdown(result)
-
-    temp_dir.cleanup()
-else:
-    st.info("Please load a dataset to proceed.")
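For reference, the core pattern the deleted script wired together, a pandas DataFrame dumped into SQLite and exposed through the langchain_community SQL tools, can be exercised standalone without Streamlit or CrewAI. Below is a minimal sketch; the sample DataFrame and the "data.db" path are illustrative assumptions, while the tool classes and the "salaries" table name are the same ones the script used.

# Minimal sketch of the deleted script's SQL layer (no Streamlit, no CrewAI).
# Assumptions: the sample rows and "data.db" file path are illustrative only.
import sqlite3
import pandas as pd
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.tools.sql_database.tool import (
    ListSQLDatabaseTool,
    QuerySQLDataBaseTool,
)

# Build the same file-backed SQLite table the app created from its DataFrame.
df = pd.DataFrame({"job_title": ["Data Scientist", "ML Engineer"], "salary": [120000, 135000]})
connection = sqlite3.connect("data.db")
df.to_sql("salaries", connection, if_exists="replace", index=False)

# Wrap the database and invoke the tools directly, as the @tool functions did.
db = SQLDatabase.from_uri("sqlite:///data.db")
print(ListSQLDatabaseTool(db=db).invoke(""))  # lists tables -> "salaries"
print(QuerySQLDataBaseTool(db=db).invoke("SELECT AVG(salary) FROM salaries;"))

Using a file-backed database rather than ":memory:" matters here: SQLDatabase.from_uri opens its own connection, so an in-memory database created through sqlite3.connect would not be visible to it. The deleted script made the same choice with its temporary-directory data.db.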