fix add_method_controller to deal properly in the event of exceptions.
add a docstring as well.
This commit is contained in:
		
							parent
							
								
									6ab36592ea
								
							
						
					
					
						commit
						a43fb9cd93
					
				|  | @ -67,20 +67,39 @@ class TagTracerSub: | ||||||
|     def get(self, name): |     def get(self, name): | ||||||
|         return self.__class__(self.root, self.tags + (name,)) |         return self.__class__(self.root, self.tags + (name,)) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def add_method_controller(cls, func): | def add_method_controller(cls, func): | ||||||
|  |     """ Use func as the method controler for the method found | ||||||
|  |     at the class named func.__name__. | ||||||
|  | 
 | ||||||
|  |     A method controler is invoked with the same arguments | ||||||
|  |     as the function it substitutes and is required to yield once | ||||||
|  |     which will trigger calling the controlled method. | ||||||
|  |     If it yields a second value, the value will be returned | ||||||
|  |     as the result of the invocation.  Errors in the controlled function | ||||||
|  |     are re-raised to the controller during the first yield. | ||||||
|  |     """ | ||||||
|     name = func.__name__ |     name = func.__name__ | ||||||
|     oldcall = getattr(cls, name) |     oldcall = getattr(cls, name) | ||||||
|     def wrap_exec(*args, **kwargs): |     def wrap_exec(*args, **kwargs): | ||||||
|         gen = func(*args, **kwargs) |         gen = func(*args, **kwargs) | ||||||
|         next(gen)   # first yield |         next(gen)   # first yield | ||||||
|         res = oldcall(*args, **kwargs) |  | ||||||
|         try: |         try: | ||||||
|             gen.send(res) |             res = oldcall(*args, **kwargs) | ||||||
|         except StopIteration: |         except Exception: | ||||||
|             pass |             excinfo = sys.exc_info() | ||||||
|  |             try: | ||||||
|  |                 # reraise exception to controller | ||||||
|  |                 res = gen.throw(*excinfo) | ||||||
|  |             except StopIteration: | ||||||
|  |                 py.builtin._reraise(*excinfo) | ||||||
|         else: |         else: | ||||||
|             raise ValueError("expected StopIteration") |             try: | ||||||
|  |                 res = gen.send(res) | ||||||
|  |             except StopIteration: | ||||||
|  |                 pass | ||||||
|         return res |         return res | ||||||
|  | 
 | ||||||
|     setattr(cls, name, wrap_exec) |     setattr(cls, name, wrap_exec) | ||||||
|     return lambda: setattr(cls, name, oldcall) |     return lambda: setattr(cls, name, oldcall) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -765,22 +765,96 @@ def test_importplugin_issue375(testdir): | ||||||
|     assert "qwe" not in str(excinfo.value) |     assert "qwe" not in str(excinfo.value) | ||||||
|     assert "aaaa" in str(excinfo.value) |     assert "aaaa" in str(excinfo.value) | ||||||
| 
 | 
 | ||||||
|  | class TestWrapMethod: | ||||||
|  |     def test_basic_happypath(self): | ||||||
|  |         class A: | ||||||
|  |             def f(self): | ||||||
|  |                 return "A.f" | ||||||
| 
 | 
 | ||||||
| def test_wrapping(): |         l = [] | ||||||
|     class A: |  | ||||||
|         def f(self): |         def f(self): | ||||||
|             return "A.f" |             l.append(1) | ||||||
|  |             yield | ||||||
|  |             l.append(2) | ||||||
|  |         undo = add_method_controller(A, f) | ||||||
| 
 | 
 | ||||||
|     l = [] |         assert A().f() == "A.f" | ||||||
|     def f(self): |         assert l == [1,2] | ||||||
|         l.append(1) |         undo() | ||||||
|         yield |         l[:] = [] | ||||||
|         l.append(2) |         assert A().f() == "A.f" | ||||||
|     undo = add_method_controller(A, f) |         assert l == [] | ||||||
| 
 | 
 | ||||||
|     assert A().f() == "A.f" |     def test_method_raises(self): | ||||||
|     assert l == [1,2] |         class A: | ||||||
|     undo() |             def error(self, val): | ||||||
|     l[:] = [] |                 raise ValueError(val) | ||||||
|     assert A().f() == "A.f" | 
 | ||||||
|     assert l == [] |         l = [] | ||||||
|  |         def error(self, val): | ||||||
|  |             l.append(val) | ||||||
|  |             try: | ||||||
|  |                 yield | ||||||
|  |             except ValueError: | ||||||
|  |                 l.append(None) | ||||||
|  |                 raise | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         undo = add_method_controller(A, error) | ||||||
|  | 
 | ||||||
|  |         with pytest.raises(ValueError): | ||||||
|  |             A().error(42) | ||||||
|  |         assert l == [42, None] | ||||||
|  |         undo() | ||||||
|  |         l[:] = [] | ||||||
|  |         with pytest.raises(ValueError): | ||||||
|  |             A().error(42) | ||||||
|  |         assert l == [] | ||||||
|  | 
 | ||||||
|  |     def test_controller_swallows_method_raises(self): | ||||||
|  |         class A: | ||||||
|  |             def error(self, val): | ||||||
|  |                 raise ValueError(val) | ||||||
|  | 
 | ||||||
|  |         def error(self, val): | ||||||
|  |             try: | ||||||
|  |                 yield | ||||||
|  |             except ValueError: | ||||||
|  |                 yield 2 | ||||||
|  | 
 | ||||||
|  |         add_method_controller(A, error) | ||||||
|  |         assert A().error(42) == 2 | ||||||
|  | 
 | ||||||
|  |     def test_reraise_on_controller_StopIteration(self): | ||||||
|  |         class A: | ||||||
|  |             def error(self, val): | ||||||
|  |                 raise ValueError(val) | ||||||
|  | 
 | ||||||
|  |         def error(self, val): | ||||||
|  |             try: | ||||||
|  |                 yield | ||||||
|  |             except ValueError: | ||||||
|  |                 pass | ||||||
|  | 
 | ||||||
|  |         add_method_controller(A, error) | ||||||
|  |         with pytest.raises(ValueError): | ||||||
|  |             A().error(42) | ||||||
|  | 
 | ||||||
|  |     @pytest.mark.xfail(reason="if needed later") | ||||||
|  |     def test_modify_call_args(self): | ||||||
|  |         class A: | ||||||
|  |             def error(self, val1, val2): | ||||||
|  |                 raise ValueError(val1+val2) | ||||||
|  | 
 | ||||||
|  |         l = [] | ||||||
|  |         def error(self): | ||||||
|  |             try: | ||||||
|  |                 yield (1,), {'val2': 2} | ||||||
|  |             except ValueError as ex: | ||||||
|  |                 assert ex.args == (3,) | ||||||
|  |                 l.append(1) | ||||||
|  | 
 | ||||||
|  |         add_method_controller(A, error) | ||||||
|  |         with pytest.raises(ValueError): | ||||||
|  |             A().error() | ||||||
|  |         assert l == [1] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue