File size: 2,287 Bytes
105b369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from pathlib import Path
from typing import Optional, Union

from phi.tools import Toolkit
from phi.utils.log import logger


class AirflowToolkit(Toolkit):
    def __init__(self, dags_dir: Optional[Union[Path, str]] = None, save_dag: bool = True, read_dag: bool = True):
        super().__init__(name="AirflowTools")

        _dags_dir: Optional[Path] = None
        if dags_dir is not None:
            if isinstance(dags_dir, str):
                _dags_dir = Path.cwd().joinpath(dags_dir)
            else:
                _dags_dir = dags_dir
        self.dags_dir: Path = _dags_dir or Path.cwd()
        if save_dag:
            self.register(self.save_dag_file, sanitize_arguments=False)
        if read_dag:
            self.register(self.read_dag_file)

    def save_dag_file(self, contents: str, dag_file: str) -> str:
        """Saves python code for an Airflow DAG to a file called `dag_file` and returns the file path if successful.

        :param contents: The contents of the DAG.
        :param dag_file: The name of the file to save to.
        :return: The file path if successful, otherwise returns an error message.
        """
        try:
            file_path = self.dags_dir.joinpath(dag_file)
            logger.debug(f"Saving contents to {file_path}")
            if not file_path.parent.exists():
                file_path.parent.mkdir(parents=True, exist_ok=True)
            file_path.write_text(contents)
            logger.info(f"Saved: {file_path}")
            return str(str(file_path))
        except Exception as e:
            logger.error(f"Error saving to file: {e}")
            return f"Error saving to file: {e}"

    def read_dag_file(self, dag_file: str) -> str:
        """Reads an Airflow DAG file `dag_file` and returns the contents if successful.

        :param dag_file: The name of the file to read
        :return: The contents of the file if successful, otherwise returns an error message.
        """
        try:
            logger.info(f"Reading file: {dag_file}")
            file_path = self.dags_dir.joinpath(dag_file)
            contents = file_path.read_text()
            return str(contents)
        except Exception as e:
            logger.error(f"Error reading file: {e}")
            return f"Error reading file: {e}"