diff --git a/_pytest/core.py b/_pytest/core.py index af888024e..f2c120cf3 100644 --- a/_pytest/core.py +++ b/_pytest/core.py @@ -67,20 +67,39 @@ class TagTracerSub: def get(self, name): return self.__class__(self.root, self.tags + (name,)) + def add_method_controller(cls, func): + """ Use func as the method controler for the method found + at the class named func.__name__. + + A method controler is invoked with the same arguments + as the function it substitutes and is required to yield once + which will trigger calling the controlled method. + If it yields a second value, the value will be returned + as the result of the invocation. Errors in the controlled function + are re-raised to the controller during the first yield. + """ name = func.__name__ oldcall = getattr(cls, name) def wrap_exec(*args, **kwargs): gen = func(*args, **kwargs) next(gen) # first yield - res = oldcall(*args, **kwargs) try: - gen.send(res) - except StopIteration: - pass + res = oldcall(*args, **kwargs) + except Exception: + excinfo = sys.exc_info() + try: + # reraise exception to controller + res = gen.throw(*excinfo) + except StopIteration: + py.builtin._reraise(*excinfo) else: - raise ValueError("expected StopIteration") + try: + res = gen.send(res) + except StopIteration: + pass return res + setattr(cls, name, wrap_exec) return lambda: setattr(cls, name, oldcall) diff --git a/testing/test_core.py b/testing/test_core.py index 5e7113974..03bc4813e 100644 --- a/testing/test_core.py +++ b/testing/test_core.py @@ -765,22 +765,96 @@ def test_importplugin_issue375(testdir): assert "qwe" not in str(excinfo.value) assert "aaaa" in str(excinfo.value) +class TestWrapMethod: + def test_basic_happypath(self): + class A: + def f(self): + return "A.f" -def test_wrapping(): - class A: + l = [] def f(self): - return "A.f" + l.append(1) + yield + l.append(2) + undo = add_method_controller(A, f) - l = [] - def f(self): - l.append(1) - yield - l.append(2) - undo = add_method_controller(A, f) + assert A().f() == "A.f" + assert l == [1,2] + undo() + l[:] = [] + assert A().f() == "A.f" + assert l == [] - assert A().f() == "A.f" - assert l == [1,2] - undo() - l[:] = [] - assert A().f() == "A.f" - assert l == [] + def test_method_raises(self): + class A: + def error(self, val): + raise ValueError(val) + + l = [] + def error(self, val): + l.append(val) + try: + yield + except ValueError: + l.append(None) + raise + + + undo = add_method_controller(A, error) + + with pytest.raises(ValueError): + A().error(42) + assert l == [42, None] + undo() + l[:] = [] + with pytest.raises(ValueError): + A().error(42) + assert l == [] + + def test_controller_swallows_method_raises(self): + class A: + def error(self, val): + raise ValueError(val) + + def error(self, val): + try: + yield + except ValueError: + yield 2 + + add_method_controller(A, error) + assert A().error(42) == 2 + + def test_reraise_on_controller_StopIteration(self): + class A: + def error(self, val): + raise ValueError(val) + + def error(self, val): + try: + yield + except ValueError: + pass + + add_method_controller(A, error) + with pytest.raises(ValueError): + A().error(42) + + @pytest.mark.xfail(reason="if needed later") + def test_modify_call_args(self): + class A: + def error(self, val1, val2): + raise ValueError(val1+val2) + + l = [] + def error(self): + try: + yield (1,), {'val2': 2} + except ValueError as ex: + assert ex.args == (3,) + l.append(1) + + add_method_controller(A, error) + with pytest.raises(ValueError): + A().error() + assert l == [1]