Project

General

Profile

Statistics
| Branch: | Revision:

root / env / lib / python2.7 / site-packages / django / utils / unittest / loader.py @ 1a305335

History | View | Annotate | Download (13.1 KB)

1
"""Loading unittests."""
2

    
3
import os
4
import re
5
import sys
6
import traceback
7
import types
8
import unittest
9

    
10
from fnmatch import fnmatch
11

    
12
from django.utils.unittest import case, suite
13

    
14
try:
15
    from os.path import relpath
16
except ImportError:
17
    from django.utils.unittest.compatibility import relpath
18

    
19
__unittest = True
20

    
21

    
22
def _CmpToKey(mycmp):
23
    'Convert a cmp= function into a key= function'
24
    class K(object):
25
        def __init__(self, obj):
26
            self.obj = obj
27
        def __lt__(self, other):
28
            return mycmp(self.obj, other.obj) == -1
29
    return K
30

    
31

    
32
# what about .pyc or .pyo (etc)
33
# we would need to avoid loading the same tests multiple times
34
# from '.py', '.pyc' *and* '.pyo'
35
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
36

    
37

    
38
def _make_failed_import_test(name, suiteClass):
39
    message = 'Failed to import test module: %s' % name
40
    if hasattr(traceback, 'format_exc'):
41
        # Python 2.3 compatibility
42
        # format_exc returns two frames of discover.py as well
43
        message += '\n%s' % traceback.format_exc()
44
    return _make_failed_test('ModuleImportFailure', name, ImportError(message),
45
                             suiteClass)
46

    
47
def _make_failed_load_tests(name, exception, suiteClass):
48
    return _make_failed_test('LoadTestsFailure', name, exception, suiteClass)
49

    
50
def _make_failed_test(classname, methodname, exception, suiteClass):
51
    def testFailure(self):
52
        raise exception
53
    attrs = {methodname: testFailure}
54
    TestClass = type(classname, (case.TestCase,), attrs)
55
    return suiteClass((TestClass(methodname),))
56

    
57

    
58
class TestLoader(unittest.TestLoader):
59
    """
60
    This class is responsible for loading tests according to various criteria
61
    and returning them wrapped in a TestSuite
62
    """
63
    testMethodPrefix = 'test'
64
    sortTestMethodsUsing = cmp
65
    suiteClass = suite.TestSuite
66
    _top_level_dir = None
67

    
68
    def loadTestsFromTestCase(self, testCaseClass):
69
        """Return a suite of all tests cases contained in testCaseClass"""
70
        if issubclass(testCaseClass, suite.TestSuite):
71
            raise TypeError("Test cases should not be derived from TestSuite."
72
                            " Maybe you meant to derive from TestCase?")
73
        testCaseNames = self.getTestCaseNames(testCaseClass)
74
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
75
            testCaseNames = ['runTest']
76
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
77
        return loaded_suite
78

    
79
    def loadTestsFromModule(self, module, use_load_tests=True):
80
        """Return a suite of all tests cases contained in the given module"""
81
        tests = []
82
        for name in dir(module):
83
            obj = getattr(module, name)
84
            if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
85
                tests.append(self.loadTestsFromTestCase(obj))
86

    
87
        load_tests = getattr(module, 'load_tests', None)
88
        tests = self.suiteClass(tests)
89
        if use_load_tests and load_tests is not None:
90
            try:
91
                return load_tests(self, tests, None)
92
            except Exception, e:
93
                return _make_failed_load_tests(module.__name__, e,
94
                                               self.suiteClass)
95
        return tests
96

    
97
    def loadTestsFromName(self, name, module=None):
98
        """Return a suite of all tests cases given a string specifier.
99

100
        The name may resolve either to a module, a test case class, a
101
        test method within a test case class, or a callable object which
102
        returns a TestCase or TestSuite instance.
103

104
        The method optionally resolves the names relative to a given module.
105
        """
106
        parts = name.split('.')
107
        if module is None:
108
            parts_copy = parts[:]
109
            while parts_copy:
110
                try:
111
                    module = __import__('.'.join(parts_copy))
112
                    break
113
                except ImportError:
114
                    del parts_copy[-1]
115
                    if not parts_copy:
116
                        raise
117
            parts = parts[1:]
118
        obj = module
119
        for part in parts:
120
            parent, obj = obj, getattr(obj, part)
121

    
122
        if isinstance(obj, types.ModuleType):
123
            return self.loadTestsFromModule(obj)
124
        elif isinstance(obj, type) and issubclass(obj, unittest.TestCase):
125
            return self.loadTestsFromTestCase(obj)
126
        elif (isinstance(obj, types.UnboundMethodType) and
127
              isinstance(parent, type) and
128
              issubclass(parent, case.TestCase)):
129
            return self.suiteClass([parent(obj.__name__)])
130
        elif isinstance(obj, unittest.TestSuite):
131
            return obj
132
        elif hasattr(obj, '__call__'):
133
            test = obj()
134
            if isinstance(test, unittest.TestSuite):
135
                return test
136
            elif isinstance(test, unittest.TestCase):
137
                return self.suiteClass([test])
138
            else:
139
                raise TypeError("calling %s returned %s, not a test" %
140
                                (obj, test))
141
        else:
142
            raise TypeError("don't know how to make test from: %s" % obj)
143

    
144
    def loadTestsFromNames(self, names, module=None):
145
        """Return a suite of all tests cases found using the given sequence
146
        of string specifiers. See 'loadTestsFromName()'.
147
        """
148
        suites = [self.loadTestsFromName(name, module) for name in names]
149
        return self.suiteClass(suites)
150

    
151
    def getTestCaseNames(self, testCaseClass):
152
        """Return a sorted sequence of method names found within testCaseClass
153
        """
154
        def isTestMethod(attrname, testCaseClass=testCaseClass,
155
                         prefix=self.testMethodPrefix):
156
            return attrname.startswith(prefix) and \
157
                hasattr(getattr(testCaseClass, attrname), '__call__')
158
        testFnNames = filter(isTestMethod, dir(testCaseClass))
159
        if self.sortTestMethodsUsing:
160
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
161
        return testFnNames
162

    
163
    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
164
        """Find and return all test modules from the specified start
165
        directory, recursing into subdirectories to find them. Only test files
166
        that match the pattern will be loaded. (Using shell style pattern
167
        matching.)
168

169
        All test modules must be importable from the top level of the project.
170
        If the start directory is not the top level directory then the top
171
        level directory must be specified separately.
172

173
        If a test package name (directory with '__init__.py') matches the
174
        pattern then the package will be checked for a 'load_tests' function. If
175
        this exists then it will be called with loader, tests, pattern.
176

177
        If load_tests exists then discovery does  *not* recurse into the package,
178
        load_tests is responsible for loading all tests in the package.
179

180
        The pattern is deliberately not stored as a loader attribute so that
181
        packages can continue discovery themselves. top_level_dir is stored so
182
        load_tests does not need to pass this argument in to loader.discover().
183
        """
184
        set_implicit_top = False
185
        if top_level_dir is None and self._top_level_dir is not None:
186
            # make top_level_dir optional if called from load_tests in a package
187
            top_level_dir = self._top_level_dir
188
        elif top_level_dir is None:
189
            set_implicit_top = True
190
            top_level_dir = start_dir
191

    
192
        top_level_dir = os.path.abspath(top_level_dir)
193

    
194
        if not top_level_dir in sys.path:
195
            # all test modules must be importable from the top level directory
196
            # should we *unconditionally* put the start directory in first
197
            # in sys.path to minimise likelihood of conflicts between installed
198
            # modules and development versions?
199
            sys.path.insert(0, top_level_dir)
200
        self._top_level_dir = top_level_dir
201

    
202
        is_not_importable = False
203
        if os.path.isdir(os.path.abspath(start_dir)):
204
            start_dir = os.path.abspath(start_dir)
205
            if start_dir != top_level_dir:
206
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
207
        else:
208
            # support for discovery from dotted module names
209
            try:
210
                __import__(start_dir)
211
            except ImportError:
212
                is_not_importable = True
213
            else:
214
                the_module = sys.modules[start_dir]
215
                top_part = start_dir.split('.')[0]
216
                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
217
                if set_implicit_top:
218
                    self._top_level_dir = os.path.abspath(os.path.dirname(os.path.dirname(sys.modules[top_part].__file__)))
219
                    sys.path.remove(top_level_dir)
220

    
221
        if is_not_importable:
222
            raise ImportError('Start directory is not importable: %r' % start_dir)
223

    
224
        tests = list(self._find_tests(start_dir, pattern))
225
        return self.suiteClass(tests)
226

    
227
    def _get_name_from_path(self, path):
228
        path = os.path.splitext(os.path.normpath(path))[0]
229

    
230
        _relpath = relpath(path, self._top_level_dir)
231
        assert not os.path.isabs(_relpath), "Path must be within the project"
232
        assert not _relpath.startswith('..'), "Path must be within the project"
233

    
234
        name = _relpath.replace(os.path.sep, '.')
235
        return name
236

    
237
    def _get_module_from_name(self, name):
238
        __import__(name)
239
        return sys.modules[name]
240

    
241
    def _match_path(self, path, full_path, pattern):
242
        # override this method to use alternative matching strategy
243
        return fnmatch(path, pattern)
244

    
245
    def _find_tests(self, start_dir, pattern):
246
        """Used by discovery. Yields test suites it loads."""
247
        paths = os.listdir(start_dir)
248

    
249
        for path in paths:
250
            full_path = os.path.join(start_dir, path)
251
            if os.path.isfile(full_path):
252
                if not VALID_MODULE_NAME.match(path):
253
                    # valid Python identifiers only
254
                    continue
255
                if not self._match_path(path, full_path, pattern):
256
                    continue
257
                # if the test file matches, load it
258
                name = self._get_name_from_path(full_path)
259
                try:
260
                    module = self._get_module_from_name(name)
261
                except:
262
                    yield _make_failed_import_test(name, self.suiteClass)
263
                else:
264
                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
265
                    realpath = os.path.splitext(mod_file)[0]
266
                    fullpath_noext = os.path.splitext(full_path)[0]
267
                    if realpath.lower() != fullpath_noext.lower():
268
                        module_dir = os.path.dirname(realpath)
269
                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
270
                        expected_dir = os.path.dirname(full_path)
271
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
272
                               "Is this module globally installed?")
273
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
274
                    yield self.loadTestsFromModule(module)
275
            elif os.path.isdir(full_path):
276
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
277
                    continue
278

    
279
                load_tests = None
280
                tests = None
281
                if fnmatch(path, pattern):
282
                    # only check load_tests if the package directory itself matches the filter
283
                    name = self._get_name_from_path(full_path)
284
                    package = self._get_module_from_name(name)
285
                    load_tests = getattr(package, 'load_tests', None)
286
                    tests = self.loadTestsFromModule(package, use_load_tests=False)
287

    
288
                if load_tests is None:
289
                    if tests is not None:
290
                        # tests loaded from package file
291
                        yield tests
292
                    # recurse into the package
293
                    for test in self._find_tests(full_path, pattern):
294
                        yield test
295
                else:
296
                    try:
297
                        yield load_tests(self, tests, pattern)
298
                    except Exception, e:
299
                        yield _make_failed_load_tests(package.__name__, e,
300
                                                      self.suiteClass)
301

    
302
defaultTestLoader = TestLoader()
303

    
304

    
305
def _makeLoader(prefix, sortUsing, suiteClass=None):
306
    loader = TestLoader()
307
    loader.sortTestMethodsUsing = sortUsing
308
    loader.testMethodPrefix = prefix
309
    if suiteClass:
310
        loader.suiteClass = suiteClass
311
    return loader
312

    
313
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
314
    return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
315

    
316
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
317
              suiteClass=suite.TestSuite):
318
    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
319

    
320
def findTestCases(module, prefix='test', sortUsing=cmp,
321
                  suiteClass=suite.TestSuite):
322
    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)