File size: 1,676 Bytes
22a452a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import inspect
import sys
from pathlib import Path
from typing import List, Type


# Directory two levels above this file (presumably the repository root — confirm).
# It is put first on sys.path so the `tests.*` modules imported in the
# `__main__` section below can be resolved when the script is run directly.
root_dir = Path(__file__).parent.parent.absolute()
sys.path.insert(0, str(root_dir))

# NOTE: the command line is parsed at module import time, not inside the
# `__main__` guard. `--type` is optional and defaults to None; the `__main__`
# section compares it against "pipeline", "models" and "lora".
parser = argparse.ArgumentParser()
parser.add_argument("--type", type=str, default=None)
args = parser.parse_args()


def get_test_methods_from_class(cls: Type) -> List[str]:
    """
    Collect the names of the test methods exposed by a class.

    A member qualifies when its name begins with 'test_' and the attribute
    is a plain function. The names are returned in sorted order.
    """
    collected = [
        member_name
        for member_name, member in inspect.getmembers(cls)
        if member_name.startswith("test_") and inspect.isfunction(member)
    ]
    collected.sort()
    return collected


def generate_pytest_pattern(test_methods: List[str]) -> str:
    """Build the expression for pytest's -k flag by OR-ing the given test names."""
    separator = " or "
    return separator.join(test_methods)


def generate_pattern_for_mixin(mixin_class: Type) -> str:
    """
    Generate pytest pattern for a specific mixin class.

    Args:
        mixin_class: Class whose `test_*` methods should be collected. `None`
            is accepted and produces an empty pattern.

    Returns:
        The sorted test method names of `mixin_class` joined with " or ",
        or an empty string when `mixin_class` is `None`.
    """
    # Guard on the parameter itself. The previous check read the module-level
    # `mixin_cls` name, which is only bound when the file runs as a script, so
    # calling this function from an importing module raised NameError.
    if mixin_class is None:
        return ""
    test_methods = get_test_methods_from_class(mixin_class)
    return generate_pytest_pattern(test_methods)


if __name__ == "__main__":
    # Map the requested --type to its tester mixin. Each import stays inside
    # its branch so only the selected test module is loaded. An unrecognised
    # or missing --type leaves the mixin as None, which prints an empty pattern.
    requested_type = args.type
    mixin_cls = None

    if requested_type == "pipeline":
        from tests.pipelines.test_pipelines_common import PipelineTesterMixin as mixin_cls
    elif requested_type == "models":
        from tests.models.test_modeling_common import ModelTesterMixin as mixin_cls
    elif requested_type == "lora":
        from tests.lora.utils import PeftLoraLoaderMixinTests as mixin_cls

    # Emit the -k expression on stdout.
    print(generate_pattern_for_mixin(mixin_cls))