File size: 10,895 Bytes
d1ceb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
"""
Module is used to infer Django model fields.
"""
from inspect import Parameter

from jedi import debug
from jedi.inference.cache import inference_state_function_cache
from jedi.inference.base_value import ValueSet, iterator_to_value_set, ValueWrapper
from jedi.inference.filters import DictFilter, AttributeOverwrite
from jedi.inference.names import NameWrapper, BaseTreeParamName
from jedi.inference.compiled.value import EmptyCompiledName
from jedi.inference.value.instance import TreeInstance
from jedi.inference.value.klass import ClassMixin
from jedi.inference.gradual.base import GenericClass
from jedi.inference.gradual.generics import TupleGenericManager
from jedi.inference.signature import AbstractSignature


# Django model field class name -> (module, attribute) naming the Python type
# the field holds on a model instance. A module of ``None`` means the type is
# looked up in the builtins module (see ``_infer_scalar_field``).
mapping = dict(
    IntegerField=(None, 'int'),
    BigIntegerField=(None, 'int'),
    PositiveIntegerField=(None, 'int'),
    SmallIntegerField=(None, 'int'),
    CharField=(None, 'str'),
    TextField=(None, 'str'),
    EmailField=(None, 'str'),
    GenericIPAddressField=(None, 'str'),
    URLField=(None, 'str'),
    FloatField=(None, 'float'),
    BinaryField=(None, 'bytes'),
    BooleanField=(None, 'bool'),
    DecimalField=('decimal', 'Decimal'),
    TimeField=('datetime', 'time'),
    DurationField=('datetime', 'timedelta'),
    DateField=('datetime', 'date'),
    DateTimeField=('datetime', 'datetime'),
    UUIDField=('uuid', 'UUID'),
)

# QuerySet method names whose signatures are replaced by the model's field
# signature (see ``tree_name_to_values``), so keyword arguments of calls like
# ``User.objects.filter(...)`` complete as model field names.
_FILTER_LIKE_METHODS = (
    'create',
    'filter',
    'exclude',
    'update',
    'get',
    'get_or_create',
    'update_or_create',
)


@inference_state_function_cache()
def _get_deferred_attributes(inference_state):
    """Return instances of ``django.db.models.query_utils.DeferredAttribute``.

    Model fields infer to these values when they are accessed on the model
    class rather than on an instance. The decorator memoizes the result
    (presumably once per inference state).
    """
    query_utils = inference_state.import_module(
        ('django', 'db', 'models', 'query_utils'))
    deferred_attribute_cls = query_utils.py__getattribute__('DeferredAttribute')
    return deferred_attribute_cls.execute_annotation()


def _infer_scalar_field(inference_state, field_name, field_tree_instance, is_instance):
    """Infer a field whose class is listed in ``mapping``.

    :param inference_state: the current inference state.
    :param field_name: not used by this helper.
    :param field_tree_instance: inferred field instance (e.g. ``CharField()``).
    :param is_instance: whether the field is accessed on a model instance.
    :returns: ``None`` if the field class is not in ``mapping``; the
        ``DeferredAttribute`` values when ``is_instance`` is false; otherwise
        the result of executing the first value found for the mapped type,
        or ``None`` if that type cannot be found.
    """
    entry = mapping.get(field_tree_instance.py__name__())
    if entry is None:
        return None

    if not is_instance:
        # On the model class a field infers to DeferredAttribute, not to the
        # type of its value.
        return _get_deferred_attributes(inference_state)

    module_name, attribute_name = entry
    if module_name is None:
        module = inference_state.builtins_module
    else:
        module = inference_state.import_module((module_name,))

    first = next(iter(module.py__getattribute__(attribute_name)), None)
    return None if first is None else first.execute_with_values()


@iterator_to_value_set
def _get_foreign_key_values(cls, field_tree_instance):
    """Yield the model classes a relation field points at.

    ``field_tree_instance`` is an inferred ``ForeignKey`` / ``OneToOneField`` /
    ``ManyToManyField`` instance. The target model (Django's ``to`` argument)
    is read from the first positional argument or, if the call starts with
    keyword arguments, from a ``to=`` keyword. Each inferred target value is
    handled as follows:

    * a class is yielded as is;
    * the literal string ``'self'`` (a recursive relationship) yields ``cls``;
    * any other literal string is looked up as a name in the module that
      contains ``cls``, and only class results are yielded.

    Anything else yields nothing. This includes non-tree instances, calls
    without a target argument, strings whose value is not statically known,
    and non-class values.
    """
    if not isinstance(field_tree_instance, TreeInstance):
        return

    # TODO private access..
    target = None
    for key, lazy_values in field_tree_instance._arguments.unpack():
        # Positional arguments precede keyword ones, so a positional argument
        # here is always the first one, i.e. ``to``.
        if key is None or key == 'to':
            target = lazy_values
            break
    if target is None:
        return

    module = cls.get_root_context()
    for value in target.infer():
        if value.py__name__() == 'str':
            # Passing a default avoids a ValueError for ``str`` values that
            # are not literals (e.g. a parameter annotated as ``str``).
            class_name = value.get_safe_value(default=None)
            if not isinstance(class_name, str):
                continue
            if class_name == 'self':
                yield cls
                continue
            for v in module.py__getattribute__(class_name):
                if v.is_class():
                    yield v
        elif value.is_class():
            yield value


def _infer_field(cls, field_name, is_instance):
    """Infer what the model attribute ``field_name`` of ``cls`` evaluates to.

    The values the attribute's assignment infers to are examined in turn:

    * field classes listed in ``mapping`` are resolved by
      ``_infer_scalar_field``;
    * ``ForeignKey`` / ``OneToOneField`` / ``ManyToManyField`` infer to
      ``DeferredAttribute`` on the class. On an instance they infer to
      instances of the related model classes, or to ``RelatedManager``
      objects for them in the many-to-many case.

    The first value handled this way decides the result. If none is handled,
    a debug message is logged and the plain inference result is returned.
    """
    inference_state = cls.inference_state
    result = field_name.infer()
    for field_tree_instance in result:
        scalar_field = _infer_scalar_field(
            inference_state, field_name, field_tree_instance, is_instance)
        if scalar_field is not None:
            return scalar_field

        field_class_name = field_tree_instance.py__name__()
        if field_class_name not in ('ForeignKey', 'OneToOneField', 'ManyToManyField'):
            continue

        if not is_instance:
            return _get_deferred_attributes(inference_state)

        related_classes = _get_foreign_key_values(cls, field_tree_instance)
        if field_class_name != 'ManyToManyField':
            return related_classes.execute_with_values()

        # _create_manager_for returns None when it cannot build a manager;
        # those entries are dropped.
        managers = (_create_manager_for(related, 'RelatedManager')
                    for related in related_classes)
        return ValueSet(filter(None, managers))

    debug.dbg('django plugin: fail to infer `%s` from class `%s`',
              field_name.string_name, cls.py__name__())
    return result


class DjangoModelName(NameWrapper):
    """Name of a model attribute whose inference goes through ``_infer_field``.

    :param cls: the model class the name was collected from.
    :param name: the wrapped original name.
    :param is_instance: whether the attribute is accessed on an instance
        (``True``) or on the class (``False``).
    """

    def __init__(self, cls, name, is_instance):
        super().__init__(name)
        self._is_instance = is_instance
        self._cls = cls

    def infer(self):
        wrapped_name = self._wrapped_name
        return _infer_field(self._cls, wrapped_name, self._is_instance)


def _create_manager_for(cls, manager_cls='BaseManager'):
    """Return an instance of ``django.db.models.manager.<manager_cls>[cls]``.

    Each class-like value found for ``manager_cls`` in
    ``django.db.models.manager`` is parameterized with ``cls`` as its only
    generic argument and executed. The first resulting value is returned,
    or ``None`` if no value could be produced.
    """
    manager_module = cls.inference_state.import_module(
        ('django', 'db', 'models', 'manager'))
    for candidate in manager_module.py__getattribute__(manager_cls):
        if not candidate.is_class_mixin():
            continue
        generics = TupleGenericManager((ValueSet([cls]),))
        instances = GenericClass(candidate, generics).execute_annotation()
        first = next(iter(instances), None)
        if first is not None:
            return first
    return None


def _new_dict_filter(cls, is_instance):
    """Build a ``DictFilter`` that exposes the attributes of ``cls`` as
    ``DjangoModelName`` objects.

    Metaclass attributes and ``type`` attributes are excluded. When a name
    appears in several of the class's filters, the filter that comes first in
    ``cls.get_filters`` order wins (presumably the class itself before its
    bases).
    """
    filters = list(cls.get_filters(
        is_instance=is_instance,
        include_metaclasses=False,
        include_type_when_class=False,
    ))

    names = {}
    # Walk the filters back to front so earlier filters overwrite later ones.
    for filter_ in reversed(filters):
        for name in filter_.values():
            names[name.string_name] = DjangoModelName(cls, name, is_instance)

    if is_instance:
        # On an instance, replace ``objects`` with a name that infers to
        # nothing. "objects" still shows up in completions, but chains like
        # ``instance.objects.filter`` are no longer inferred. Hiding it from
        # completions as well would need a more involved approach, which is
        # probably not worth the extra work.
        names['objects'] = EmptyCompiledName(cls.inference_state, 'objects')

    return DictFilter(names)


def is_django_model_base(value):
    """Return whether ``value`` is Django's ``ModelBase`` metaclass.

    That is, a value named ``ModelBase`` whose root module is
    ``django.db.models.base``. The module is only consulted when the name
    matches.
    """
    if value.py__name__() != 'ModelBase':
        return False
    return value.get_root_context().py__name__() == 'django.db.models.base'


def get_metaclass_filters(func):
    """Plugin hook: replace attribute filters for Django model classes.

    ``func`` is the next implementation in the chain. The returned wrapper
    handles the case where any of ``metaclasses`` is Django's ``ModelBase``.
    In that case the class's attributes are served through a single
    ``_new_dict_filter``. Otherwise the call is delegated to ``func``
    unchanged.
    """
    def wrapper(cls, metaclasses, is_instance):
        if any(is_django_model_base(meta) for meta in metaclasses):
            return [_new_dict_filter(cls, is_instance)]
        return func(cls, metaclasses, is_instance)
    return wrapper


def tree_name_to_values(func):
    """Plugin hook: wrap selected Django names after normal inference.

    ``func`` is the next implementation in the chain; its result is
    post-processed as follows:

    * ``_BaseQuerySet`` methods named in ``_FILTER_LIKE_METHODS`` and defined
      in ``django.db.models.query`` become ``QuerySetMethodWrapper`` objects,
      one per model in the first generic of the current value. This lets
      keyword completion for calls like ``User.objects.filter`` offer field
      names.
    * ``BaseManager`` in ``django.db.models.manager`` is wrapped in
      ``ManagerWrapper``.
    * ``Field`` in ``django.db.models.fields`` is wrapped in ``FieldWrapper``.

    Every other result is returned unchanged.
    """
    def wrapper(inference_state, context, tree_name):
        result = func(inference_state, context, tree_name)
        string_name = tree_name.value

        if string_name in _FILTER_LIKE_METHODS:
            for method in result:
                is_queryset_method = (
                    method.get_qualified_names() == ('_BaseQuerySet', string_name)
                    and method.parent_context.is_module()
                    and method.parent_context.py__name__() == 'django.db.models.query'
                )
                if not is_queryset_method:
                    continue
                generics = context.get_value().get_generics()
                if len(generics) >= 1:
                    return ValueSet(
                        QuerySetMethodWrapper(method, model)
                        for model in generics[0]
                    )
            return result

        if string_name == 'BaseManager' and context.is_module() \
                and context.py__name__() == 'django.db.models.manager':
            return ValueSet(map(ManagerWrapper, result))

        if string_name == 'Field' and context.is_module() \
                and context.py__name__() == 'django.db.models.fields':
            return ValueSet(map(FieldWrapper, result))

        return result
    return wrapper


def _find_fields(cls):
    """Yield the class-level names of ``cls`` that are Django model fields.

    A name counts as a field when, inferred on the class
    (``is_instance=False``), one of its values is
    ``django.db.models.query_utils.DeferredAttribute``. That is what
    ``_infer_field`` produces for scalar and relation fields.

    Each name is yielded at most once, even if its inference produces several
    matching values. Yielding it once per value would create duplicate
    keyword parameters in the model signature.
    """
    deferred_attribute = ('django', 'db', 'models', 'query_utils', 'DeferredAttribute')
    for name in _new_dict_filter(cls, is_instance=False).values():
        if any(
            value.name.get_qualified_names(include_module_names=True)
            == deferred_attribute
            for value in name.infer()
        ):
            yield name


def _get_signatures(cls):
    """Return a one-element list holding the ``DjangoModelSignature`` of the
    model ``cls``. Its parameters are the names found by ``_find_fields``.
    """
    field_names = list(_find_fields(cls))
    signature = DjangoModelSignature(cls, field_names=field_names)
    return [signature]


def get_metaclass_signatures(func):
    """Plugin hook: provide signatures for classes built by Django's ``ModelBase``.

    ``func`` is the next implementation in the chain. The returned wrapper
    returns ``_get_signatures(cls)`` when any of ``metaclasses`` is Django's
    ``ModelBase``. Otherwise it delegates to ``func`` with the same
    ``(cls, metaclasses)`` arguments it received.
    """
    def wrapper(cls, metaclasses):
        for metaclass in metaclasses:
            if is_django_model_base(metaclass):
                return _get_signatures(cls)
        # Forward the whole collection, not the loop variable. The loop
        # variable is only the last metaclass, and is unbound when
        # ``metaclasses`` is empty.
        return func(cls, metaclasses)
    return wrapper


class ManagerWrapper(ValueWrapper):
    """Wrapper around Django's ``BaseManager`` class.

    Subscripting it (``BaseManager[...]``) wraps every result of the
    underlying subscription in a ``GenericManagerWrapper``.
    """

    def py__getitem__(self, index_value_set, contextualized_node):
        subscripted = self._wrapped_value.py__getitem__(
            index_value_set, contextualized_node)
        return ValueSet(map(GenericManagerWrapper, subscripted))


class GenericManagerWrapper(AttributeOverwrite, ClassMixin):
    """A subscripted ``BaseManager`` class, as produced by ``ManagerWrapper``."""

    def py__get__on_class(self, calling_instance, instance, class_value):
        # When the manager is retrieved through a class (presumably
        # ``Model.objects``-style access), rebuild it with ``class_value`` as
        # its generic argument and re-run the constructor with the original
        # call's arguments.
        specialized = calling_instance.class_value.with_generics(
            (ValueSet({class_value}),))
        original_arguments = calling_instance._arguments
        return specialized.py__call__(original_arguments)

    def with_generics(self, generics_tuple):
        wrapped = self._wrapped_value
        return wrapped.with_generics(generics_tuple)


class FieldWrapper(ValueWrapper):
    """Wrapper around Django's ``Field`` class.

    Subscripting it (``Field[...]``) wraps every result of the underlying
    subscription in a ``GenericFieldWrapper``.
    """

    def py__getitem__(self, index_value_set, contextualized_node):
        subscripted = self._wrapped_value.py__getitem__(
            index_value_set, contextualized_node)
        return ValueSet(map(GenericFieldWrapper, subscripted))


class GenericFieldWrapper(AttributeOverwrite, ClassMixin):
    """A subscripted ``Field`` class, as produced by ``FieldWrapper``."""

    def py__get__on_class(self, calling_instance, instance, class_value):
        # Short-circuit the descriptor lookup by returning the field instance
        # itself. This is mainly an optimization: it stops Jedi from aborting
        # inference because Field.__get__ would be executed too many times.
        return ValueSet([calling_instance])


class DjangoModelSignature(AbstractSignature):
    """Signature of a Django model (or filter-like QuerySet method).

    It has one ``DjangoParamName`` per entry of ``field_names``.
    """

    def __init__(self, value, field_names):
        super().__init__(value)
        self._field_names = field_names

    def get_param_names(self, resolve_stars=False):
        # ``resolve_stars`` is accepted for interface compatibility only; the
        # parameters here never contain star arguments.
        return list(map(DjangoParamName, self._field_names))


class DjangoParamName(BaseTreeParamName):
    """Signature parameter backed by a model field name.

    It reuses the field's parent context and tree name, and delegates
    inference to the field name.
    """

    def __init__(self, field_name):
        parent_context = field_name.parent_context
        tree_name = field_name.tree_name
        super().__init__(parent_context, tree_name)
        self._field_name = field_name

    def get_kind(self):
        # Fields are always offered as keyword-only parameters.
        return Parameter.KEYWORD_ONLY

    def infer(self):
        field_name = self._field_name
        return field_name.infer()


class QuerySetMethodWrapper(ValueWrapper):
    """A filter-like ``_BaseQuerySet`` method tied to a specific model class.

    Binding it (``py__get__``) produces ``QuerySetBoundMethodWrapper`` objects
    that carry the same model class.
    """

    def __init__(self, method, model_cls):
        super().__init__(method)
        self._model_cls = model_cls

    def py__get__(self, instance, class_value):
        bound_wrappers = set()
        for bound in self._wrapped_value.py__get__(instance, class_value):
            bound_wrappers.add(QuerySetBoundMethodWrapper(bound, self._model_cls))
        return ValueSet(bound_wrappers)


class QuerySetBoundMethodWrapper(ValueWrapper):
    """A bound filter-like QuerySet method.

    Its signatures are replaced by the model's field signature, which offers
    keyword-only parameters named after the model's fields.
    """

    def __init__(self, method, model_cls):
        super().__init__(method)
        self._model_cls = model_cls

    def get_signatures(self):
        model_cls = self._model_cls
        return _get_signatures(model_cls)