mgbam commited on
Commit
17b2c9a
·
verified ·
1 Parent(s): 773f0cf

Update db_connector.py

Browse files
Files changed (1) hide show
  1. db_connector.py +33 -9
db_connector.py CHANGED
@@ -1,17 +1,41 @@
1
-
 
2
  import pandas as pd
3
  from sqlalchemy import create_engine, inspect
4
 
5
- SUPPORTED_ENGINES = ["SQLite", "PostgreSQL", "MySQL", "MSSQL", "Oracle"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- def list_tables(conn_str):
8
  engine = create_engine(conn_str)
9
- inspector = inspect(engine)
10
- return inspector.get_table_names()
11
 
12
- def fetch_data_from_db(conn_str, table):
13
  engine = create_engine(conn_str)
14
  df = pd.read_sql_table(table, engine)
15
- csv_path = f"data/{table}_extracted.csv"
16
- df.to_csv(csv_path, index=False)
17
- return csv_path
 
 
 
 
 
 
1
+ # db_connector.py
2
+ import os
3
  import pandas as pd
4
  from sqlalchemy import create_engine, inspect
5
 
6
+ SUPPORTED_ENGINES = ["PostgreSQL", "MySQL", "SQLite", "MSSQL", "Oracle"]
7
+
8
+ def _env_to_uri(engine: str) -> str | None:
9
+ """Build SQLAlchemy URI from env vars; return None if vars missing."""
10
+ match engine:
11
+ case "PostgreSQL":
12
+ host = os.getenv("PG_HOST"); port = os.getenv("PG_PORT", "5432")
13
+ db = os.getenv("PG_DB"); user = os.getenv("PG_USER"); pw = os.getenv("PG_PW")
14
+ if all([host, db, user, pw]):
15
+ return f"postgresql://{user}:{pw}@{host}:{port}/{db}"
16
+ case "MySQL":
17
+ host = os.getenv("MYSQL_HOST"); port = os.getenv("MYSQL_PORT", "3306")
18
+ db = os.getenv("MYSQL_DB"); user = os.getenv("MYSQL_USER"); pw = os.getenv("MYSQL_PW")
19
+ if all([host, db, user, pw]):
20
+ return f"mysql+mysqlconnector://{user}:{pw}@{host}:{port}/{db}"
21
+ case "MSSQL":
22
+ if os.getenv("MSSQL_CONN_STR"):
23
+ return os.getenv("MSSQL_CONN_STR")
24
+ # add Oracle, etc.
25
+ return None
26
 
27
+ def list_tables(conn_str: str):
28
  engine = create_engine(conn_str)
29
+ return inspect(engine).get_table_names()
 
30
 
31
+ def fetch_data_from_db(conn_str: str, table: str) -> str:
32
  engine = create_engine(conn_str)
33
  df = pd.read_sql_table(table, engine)
34
+ tmp_path = os.path.join(tempfile.gettempdir(), f"{table}.csv")
35
+ df.to_csv(tmp_path, index=False)
36
+ return tmp_path
37
+
38
+ def get_connection_string(engine: str, manual_input: str | None) -> str | None:
39
+ """Prefer env‑vars; fallback to user input."""
40
+ auto = _env_to_uri(engine)
41
+ return auto or manual_input