diff --git a/conda/env.yml b/conda/env.yml index bd3c40b..a6603fd 100644 Binary files a/conda/env.yml and b/conda/env.yml differ diff --git a/tests/test_definitions.py b/tests/test_definitions.py index c7777af..ae89efe 100644 --- a/tests/test_definitions.py +++ b/tests/test_definitions.py @@ -3,7 +3,7 @@ from test_utils import * def test_function_def(): - @Commentor() + @Commentor("") def target(): pass diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 38d672c..7ef6de8 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -4,7 +4,7 @@ from test_utils import asserteq_or_print def test_binop(): - @Commentor() + @Commentor("") def target(): 1 + 1 @@ -20,7 +20,7 @@ def test_binop(): def test_binop_cascade(): - @Commentor() + @Commentor("") def target(): 1 + 1 + 1 @@ -33,3 +33,19 @@ def test_binop_cascade(): 3 : 1 + 1 + 1 """ ''') + + +def test_call_print(): + + @Commentor("") + def target(): + print("This line will be printed.") + + asserteq_or_print( + target(), ''' + def target(): + print('This line will be printed.') + """ + None : print('This line will be printed.') + """ +''') diff --git a/tests/test_literals.py b/tests/test_literals.py index 1c95cfb..3bd5ce3 100644 --- a/tests/test_literals.py +++ b/tests/test_literals.py @@ -3,12 +3,11 @@ from test_utils import * def test_constant(): - @Commentor() + @Commentor("") def target(): 1 - - asserteq_or_print(target(), - ''' + + asserteq_or_print(target(), ''' def target(): 1 ''') diff --git a/tests/test_torch.py b/tests/test_torch.py deleted file mode 100644 index 744ce3f..0000000 --- a/tests/test_torch.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch -import torch.nn as nn - -from trace_commentor.parser import * - - -def test_func_no_arg(): - - @analyse - def target(): - - x = torch.ones(4, 5) - for i in range(3): - x = x[..., None, :] - - a = torch.randn(309, 110, 3)[:100] - f = nn.Linear(3, 128) - b = f(a.reshape(-1, 3)).reshape(309, 110, 128) - c = torch.concat((a, b), dim=-1) - - return c.flatten() - - print() - target() - - -def test_for_loop(): - - @analyse - def target(): - a = 1 - for i in range(3): - a += 1 - print(a) - - print() - target() diff --git a/tests/test_variables.py b/tests/test_variables.py new file mode 100644 index 0000000..422f2f5 --- /dev/null +++ b/tests/test_variables.py @@ -0,0 +1,20 @@ +from test_utils import * + + +def test_assign(): + + @Commentor("") + def target(): + myint = 7 + print(myint) + + asserteq_or_print( + target(), ''' + def target(): + myint = 7 + print(myint) + """ + 7 : myint + None : print(myint) + """ +''') diff --git a/trace_commentor/commentor.py b/trace_commentor/commentor.py index a45a2e5..14702b9 100644 --- a/trace_commentor/commentor.py +++ b/trace_commentor/commentor.py @@ -1,5 +1,7 @@ -import inspect import ast +import inspect +import sys +import rich from inspect import getfullargspec from functools import wraps @@ -12,13 +14,15 @@ from .utils import sign, to_source class Commentor(object): - def __init__(self, _formatters=[]) -> None: + def __init__(self, output="", fmt=[]) -> None: self._locals = dict() self._globals = dict() - self._formatters = _formatters + formatters.LIST + self._return = None + self._formatters = fmt + formatters.LIST self._lines = [] self.indent = 0 self.state = flags.SOURCE + self.file = output def __call__(self, func): @@ -28,14 +32,30 @@ class Commentor(object): self.root = ast.parse(unindented_source).body[0] if flags.DEBUG: - with open("test.log", "wt") as f: + with open("debug.log", "wt") as f: print(ast.dump(self.root, indent=4), file=f) @wraps(func) def proxy_func(*args, **kwargs): + # input { self._locals = kwargs + # } + self.process(self.root) - return "\n".join(self._lines) + + # output { + code = "\n".join(self._lines) + if self.file == "": + return code + elif self.file == "": + rich.print(code, file=sys.stderr) + elif self.file == "": + rich.print(code, file=sys.stdout) + else: + with open(self.file, "wt") as f: + rich.print(code, file=f) + return self._return + # } return proxy_func @@ -51,6 +71,10 @@ class Commentor(object): obj = eval(src, self._globals, self._locals) fmt = self.get_formatter(obj) return f"{fmt(obj)} : {src}" + + def exec(self, node: ast.stmt): + src = to_source(node) + exec(src, self._globals, self._locals) def get_formatter(self, obj): for typ, fmt in self._formatters: diff --git a/trace_commentor/handlers/__init__.py b/trace_commentor/handlers/__init__.py index 977533e..5abfd4c 100644 --- a/trace_commentor/handlers/__init__.py +++ b/trace_commentor/handlers/__init__.py @@ -1,4 +1,5 @@ from .definitions import FunctionDef -from .statements import Pass -from .expressions import Expr, BinOp +from .statements import Pass, Assign +from .expressions import Expr, BinOp, Call from .literals import Constant +from .variables import Name diff --git a/trace_commentor/handlers/expressions.py b/trace_commentor/handlers/expressions.py index 01b64e8..82e81ee 100644 --- a/trace_commentor/handlers/expressions.py +++ b/trace_commentor/handlers/expressions.py @@ -5,3 +5,10 @@ def BinOp(self, cmtor): cmtor.process(self.left) cmtor.process(self.right) cmtor.append_comment(cmtor.eval(self)) + +def Call(self, cmtor): + for arg in self.args: + cmtor.process(arg) + for kwarg in self.keywords: + cmtor.process(kwarg) + cmtor.append_comment(cmtor.eval(self)) diff --git a/trace_commentor/handlers/statements.py b/trace_commentor/handlers/statements.py index cf7c449..7b850be 100644 --- a/trace_commentor/handlers/statements.py +++ b/trace_commentor/handlers/statements.py @@ -1,2 +1,6 @@ def Pass(self, cmtor): pass + +def Assign(self, cmtor): + cmtor.process(self.value) + cmtor.exec(self) diff --git a/trace_commentor/handlers/variables.py b/trace_commentor/handlers/variables.py new file mode 100644 index 0000000..b377a67 --- /dev/null +++ b/trace_commentor/handlers/variables.py @@ -0,0 +1,2 @@ +def Name(self, cmtor): + cmtor.append_comment(cmtor.eval(self))