diff --git a/experiment.py b/experiment.py new file mode 100644 index 0000000..988974b --- /dev/null +++ b/experiment.py @@ -0,0 +1,15 @@ +from tests.test_utils import * + +@Commentor(fmt=[ + (type(None), lambda o: None), + +]) +def function(): + mystring = 'hello' + print(mystring) + mystring = "hello" + print(mystring) + return 1 + + +print(function()) diff --git a/tests/test_definitions.py b/tests/test_definitions.py index ae89efe..7d11a7d 100644 --- a/tests/test_definitions.py +++ b/tests/test_definitions.py @@ -11,3 +11,26 @@ def test_function_def(): def target(): pass ''') + + +def test_return(): + + with closing(StringIO()) as f: + + @Commentor(f) + def target(): + a = 1 + return a + 1 + + assert target() == 2 + + asserteq_or_print( + f.getvalue(), ''' + def target(): + a = 1 + return a + 1 + """ + 1 : a + 2 : a + 1 + """ +''') diff --git a/tests/test_utils.py b/tests/test_utils.py index b6b9dad..37dfb9d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,11 @@ +from io import StringIO +from contextlib import closing from trace_commentor import flags, Commentor def asserteq_or_print(value, ground_truth): if flags.DEBUG or flags.PRINT: print(value) else: - assert value == ground_truth.strip("\n"), "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth.strip("\n"), ">>>>>>>> GROUND\n"]) + value = value.strip("\n").rstrip(" ") + ground_truth = ground_truth.strip("\n").rstrip(" ") + assert value == ground_truth, "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth, ">>>>>>>> GROUND\n"]) diff --git a/trace_commentor/commentor.py b/trace_commentor/commentor.py index 14702b9..0d92adc 100644 --- a/trace_commentor/commentor.py +++ b/trace_commentor/commentor.py @@ -5,6 +5,7 @@ import rich from inspect import getfullargspec from functools import wraps +from io import IOBase from . import handlers from . import formatters @@ -51,9 +52,13 @@ class Commentor(object): rich.print(code, file=sys.stderr) elif self.file == "": rich.print(code, file=sys.stdout) - else: + elif isinstance(self.file, IOBase): + rich.print(code, file=self.file) + elif type(self.file) == str: with open(self.file, "wt") as f: rich.print(code, file=f) + else: + raise NotImplementedError(f"Unknown file protocal {self.file}") return self._return # } @@ -66,11 +71,16 @@ class Commentor(object): raise NotImplementedError(f"Unknown how to handle {node_type} node.") return handler(node, self) - def eval(self, node: ast.Expr): + def eval(self, node: ast.Expr, format=True): src = to_source(node) obj = eval(src, self._globals, self._locals) + if not format: + return obj + fmt = self.get_formatter(obj) - return f"{fmt(obj)} : {src}" + fmt_obj = fmt(obj) + if fmt_obj is not None: + return f"{fmt(obj)} : {src}" def exec(self, node: ast.stmt): src = to_source(node) diff --git a/trace_commentor/handlers/__init__.py b/trace_commentor/handlers/__init__.py index 5abfd4c..8b4a20e 100644 --- a/trace_commentor/handlers/__init__.py +++ b/trace_commentor/handlers/__init__.py @@ -1,4 +1,4 @@ -from .definitions import FunctionDef +from .definitions import FunctionDef, Return from .statements import Pass, Assign from .expressions import Expr, BinOp, Call from .literals import Constant diff --git a/trace_commentor/handlers/definitions.py b/trace_commentor/handlers/definitions.py index fcf7dfd..2b4f24c 100644 --- a/trace_commentor/handlers/definitions.py +++ b/trace_commentor/handlers/definitions.py @@ -15,3 +15,8 @@ def FunctionDef(self, cmtor): cmtor.append_source() cmtor.indent -= flags.INDENT + + +def Return(self, cmtor): + cmtor.process(self.value) + cmtor._return = cmtor.eval(self.value, format=False)