diff --git a/experiment.py b/experiment.py index 9d217eb..6d93872 100644 --- a/experiment.py +++ b/experiment.py @@ -1,12 +1,26 @@ from tests.test_utils import * -@Commentor(fmt=[ - (type(None), lambda o: None), - -]) -def function(): - x = 2 - print(x == 2) + +def test(): + + @Commentor() + def target(): + x = 2 + if x > 3: + x = 2 * x + y = 1 + elif x > 2: + x = 4 * x + y = 2 + elif x > 3: + x = 4 * x + y = 3 + else: + x = 8 * x + y = 5 -print(function()) + print(target()) + + +test() diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 7d8a543..3ab78c1 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -17,3 +17,45 @@ def test_constant(): None : print(x == 2) """ ''') + + +def test(): + + @Commentor("") + def target(): + x = 2 + if x > 3: + x = 2 * x + y = 1 + elif x > 2: + x = 4 * x + y = 2 + elif x > 3: + x = 4 * x + y = 3 + else: + x = 8 * x + y = 5 + + asserteq_or_print(target(), ''' + def target(): + x = 2 + if x > 3: # False + x = 2 * x + y = 1 + elif x > 2: # False + x = 4 * x + y = 2 + elif x > 3: # False + x = 4 * x + y = 3 + else: # True + x = 8 * x + """ + 2 : x + 16 : 8 * x + ---------- + 16 : x + """ + y = 5 +''') diff --git a/tests/test_literals.py b/tests/test_literals.py index 746c969..31a0134 100644 --- a/tests/test_literals.py +++ b/tests/test_literals.py @@ -23,7 +23,7 @@ def test_tuple(): def target(): a, b = 1, 2 """ - ======== + ---------- 1 : a 2 : b """ diff --git a/trace_commentor/commentor.py b/trace_commentor/commentor.py index 0d92adc..8cbb45c 100644 --- a/trace_commentor/commentor.py +++ b/trace_commentor/commentor.py @@ -64,12 +64,12 @@ class Commentor(object): return proxy_func - def process(self, node: ast.AST): + def process(self, node: ast.AST, *args, **kwargs): node_type = node.__class__.__name__ handler = getattr(handlers, node_type, None) if handler is None: raise NotImplementedError(f"Unknown how to handle {node_type} node.") - return handler(node, self) + return handler(node, self, *args, **kwargs) def eval(self, node: ast.Expr, format=True): src = to_source(node) diff --git a/trace_commentor/handlers/__init__.py b/trace_commentor/handlers/__init__.py index fc28924..3ce26e1 100644 --- a/trace_commentor/handlers/__init__.py +++ b/trace_commentor/handlers/__init__.py @@ -3,3 +3,4 @@ from .statements import Pass, Assign from .expressions import Expr, BinOp, Call, Compare from .literals import Constant, Tuple from .variables import Name +from .control_flow import If diff --git a/trace_commentor/handlers/control_flow.py b/trace_commentor/handlers/control_flow.py new file mode 100644 index 0000000..ba6b902 --- /dev/null +++ b/trace_commentor/handlers/control_flow.py @@ -0,0 +1,50 @@ +import ast +from .. import flags +from ..utils import to_source, APPEND_SOURCE_BY_THEMSELVES + + +ELIF = 2 +PASS = 4 + + +def If(self, cmtor, state=0): + + if state & PASS: + test = False + test_comment = "skipped" + else: + test = cmtor.eval(self.test, format=False) + test_comment = test + if test: + state = state | PASS + + if state & ELIF: + cmtor.append_source(f"elif {to_source(self.test)}: # {test_comment}") + else: + cmtor.append_source(f"if {to_source(self.test)}: # {test_comment}") + + cmtor.indent += flags.INDENT + for stmt in self.body: + if type(stmt) not in APPEND_SOURCE_BY_THEMSELVES: + cmtor.append_source(to_source(stmt)) + if test: + cmtor.process(stmt) + cmtor.append_source() + cmtor.indent -= flags.INDENT + + if self.orelse: + if type(self.orelse[0]) == ast.If: + cmtor.process(self.orelse[0], state=state | ELIF) + else: + test = not (state & PASS) + test_comment = True if test else "skipped" + cmtor.append_source(f"else: # {test_comment}") + + cmtor.indent += flags.INDENT + for stmt in self.orelse: + if type(stmt) not in APPEND_SOURCE_BY_THEMSELVES: + cmtor.append_source(to_source(stmt)) + if test: + cmtor.process(stmt) + cmtor.append_source() + cmtor.indent -= flags.INDENT diff --git a/trace_commentor/handlers/definitions.py b/trace_commentor/handlers/definitions.py index 2b4f24c..334e151 100644 --- a/trace_commentor/handlers/definitions.py +++ b/trace_commentor/handlers/definitions.py @@ -1,5 +1,6 @@ +import ast from .. import flags -from ..utils import to_source +from ..utils import to_source, APPEND_SOURCE_BY_THEMSELVES def FunctionDef(self, cmtor): cmtor.append_source(f"def {self.name}():") @@ -7,7 +8,8 @@ def FunctionDef(self, cmtor): for stmt in self.body: - cmtor.append_source(to_source(stmt)) + if type(stmt) not in APPEND_SOURCE_BY_THEMSELVES: + cmtor.append_source(to_source(stmt)) if self is cmtor.root: cmtor.process(stmt) diff --git a/trace_commentor/handlers/statements.py b/trace_commentor/handlers/statements.py index a780e19..649efdb 100644 --- a/trace_commentor/handlers/statements.py +++ b/trace_commentor/handlers/statements.py @@ -8,6 +8,6 @@ def Assign(self, cmtor): cmtor.process(self.value) cmtor.exec(self) if type(self.value) not in [ast.Constant]: - cmtor.append_comment("========") + cmtor.append_comment("----------") for target in self.targets: cmtor.process(target) diff --git a/trace_commentor/utils.py b/trace_commentor/utils.py index 37da372..27e77c4 100644 --- a/trace_commentor/utils.py +++ b/trace_commentor/utils.py @@ -1,8 +1,15 @@ -import os +import ast import astor import inspect +import os from . import flags + +APPEND_SOURCE_BY_THEMSELVES = [ + ast.If, +] + + def sign(line: str, depth=1): if flags.DEBUG: currentframe = inspect.currentframe() @@ -12,5 +19,9 @@ def sign(line: str, depth=1): else: return line + def to_source(node): - return astor.to_source(node).rstrip("\n") + src = astor.to_source(node).rstrip("\n") + if type(node) != ast.Tuple and len(src) > 2 and src[0] == "(" and src[-1] == ")": + src = src[1:-1] + return src