diff --git a/tests/test_control_flow.py b/tests/test_control_flow.py index 4ce60f7..2c0c71e 100644 --- a/tests/test_control_flow.py +++ b/tests/test_control_flow.py @@ -43,12 +43,15 @@ def test_if(): if x > 3: # False x = 2 * x # skipped y = 1 # skipped + elif x > 2: # False x = 4 * x # skipped y = 2 # skipped + elif x > 3: # False x = 4 * x # skipped y = 3 # skipped + else: # True x = 8 * x """ @@ -57,7 +60,9 @@ def test_if(): ---------- 16 : x """ + y = 5 + ''') @@ -81,133 +86,20 @@ def test_for(): odds = [] for x in range(10): - ###### !new iteration! ###### - """ - 0 : __REG__for_loop_iter_once - ---------- - 0 : x - """ - # if x % 2 == 0: # True - # continue # True - # odds.append(x) # skipped - - ###### !new iteration! ###### - """ - 1 : __REG__for_loop_iter_once - ---------- - 1 : x - """ - # if x % 2 == 0: # False - # continue # skipped - # odds.append(x) - """ - [] : odds - 1 : x - None : odds.append(x) - """ - - ###### !new iteration! ###### - """ - 2 : __REG__for_loop_iter_once - ---------- - 2 : x - """ - # if x % 2 == 0: # True - # continue # True - # odds.append(x) # skipped - - ###### !new iteration! ###### - """ - 3 : __REG__for_loop_iter_once - ---------- - 3 : x - """ - # if x % 2 == 0: # False - # continue # skipped - # odds.append(x) - """ - [1] : odds - 3 : x - None : odds.append(x) - """ - - ###### !new iteration! ###### - """ - 4 : __REG__for_loop_iter_once - ---------- - 4 : x - """ - # if x % 2 == 0: # True - # continue # True - # odds.append(x) # skipped - - ###### !new iteration! ###### - """ - 5 : __REG__for_loop_iter_once - ---------- - 5 : x - """ - # if x % 2 == 0: # False - # continue # skipped - # odds.append(x) - """ - [1, 3] : odds - 5 : x - None : odds.append(x) - """ - - ###### !new iteration! ###### - """ - 6 : __REG__for_loop_iter_once - ---------- - 6 : x - """ - # if x % 2 == 0: # True - # continue # True - # odds.append(x) # skipped - - ###### !new iteration! ###### - """ - 7 : __REG__for_loop_iter_once - ---------- - 7 : x - """ - # if x % 2 == 0: # False - # continue # skipped - # odds.append(x) - """ - [1, 3, 5] : odds - 7 : x - None : odds.append(x) - """ - - ###### !new iteration! ###### - """ - 8 : __REG__for_loop_iter_once - ---------- - 8 : x - """ - # if x % 2 == 0: # True - # continue # True - # odds.append(x) # skipped - - ###### !new iteration! ###### - """ - 9 : __REG__for_loop_iter_once - ---------- - 9 : x - """ if x % 2 == 0: # False continue # skipped + odds.append(x) """ [1, 3, 5, 7] : odds 9 : x None : odds.append(x) """ + return odds """ [1, 3, 5, 7, 9] : odds - """''') + """ +''') diff --git a/tests/test_definitions.py b/tests/test_definitions.py index ac46b57..afeff9f 100644 --- a/tests/test_definitions.py +++ b/tests/test_definitions.py @@ -53,6 +53,7 @@ def test_args(): 5 : c 2 : k """ + return a + k """ 1 : a diff --git a/tests/test_torch.py b/tests/test_torch.py index 26eebaf..b2eedfa 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -21,87 +21,60 @@ def test_torch(): return c.flatten() asserteq_or_print( - target(), ''' def target(): + target(), ''' + def target(): x = torch.ones(4, 5) """ - [4, 5] : torch.ones(4, 5) + Tensor((4, 5), f32) : torch.ones(4, 5) ---------- - [4, 5] : x + Tensor((4, 5), f32) : x """ + for i in range(3): - ###### !new iteration! ###### - """ - 0 : __REG__for_loop_iter_once - ---------- - 0 : i - """ - # x = x[..., None, :] - """ - [4, 5] : x - [4, 1, 5] : x[..., None, :] - ---------- - [4, 1, 5] : x - """ - - ###### !new iteration! ###### - """ - 1 : __REG__for_loop_iter_once - ---------- - 1 : i - """ - # x = x[..., None, :] - """ - [4, 1, 5] : x - [4, 1, 1, 5] : x[..., None, :] - ---------- - [4, 1, 1, 5] : x - """ - - ###### !new iteration! ###### - """ - 2 : __REG__for_loop_iter_once - ---------- - 2 : i - """ x = x[..., None, :] """ - [4, 1, 1, 5] : x - [4, 1, 1, 1, 5] : x[..., None, :] + Tensor((4, 1, 1, 5), f32) : x + Tensor((4, 1, 1, 1, 5), f32) : x[..., None, :] ---------- - [4, 1, 1, 1, 5] : x + Tensor((4, 1, 1, 1, 5), f32) : x """ + a = torch.randn(309, 110, 3)[:100] """ - [309, 110, 3] : torch.randn(309, 110, 3) - [100, 110, 3] : torch.randn(309, 110, 3)[:100] + Tensor((309, 110, 3), f32) : torch.randn(309, 110, 3) + Tensor((100, 110, 3), f32) : torch.randn(309, 110, 3)[:100] ---------- - [100, 110, 3] : a + Tensor((100, 110, 3), f32) : a """ + f = nn.Linear(3, 128) """ ---------- """ + b = f(a.reshape(-1, 3)).reshape(-1, 110, 128) """ - [100, 110, 3] : a - [11000, 3] : a.reshape(-1, 3) - [11000, 128] : f(a.reshape(-1, 3)) - [100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128) + Tensor((100, 110, 3), f32) : a + Tensor((11000, 3), f32) : a.reshape(-1, 3) + Tensor((11000, 128), f32) : f(a.reshape(-1, 3)) + Tensor((100, 110, 128), f32) : f(a.reshape(-1, ... pe(-1, 110, 128) ---------- - [100, 110, 128] : b + Tensor((100, 110, 128), f32) : b """ + c = torch.concat((a, b), dim=-1) """ - [100, 110, 3] : a - [100, 110, 128] : b - [100, 110, 131] : torch.concat((a, b), dim=-1) + Tensor((100, 110, 3), f32) : a + Tensor((100, 110, 128), f32) : b + Tensor((100, 110, 131), f32) : torch.concat((a, b), dim=-1) ---------- - [100, 110, 131] : c + Tensor((100, 110, 131), f32) : c """ + return c.flatten() """ - [100, 110, 131] : c - [1441000] : c.flatten() + Tensor((100, 110, 131), f32) : c + Tensor((1441000,), f32) : c.flatten() """ ''') diff --git a/tests/test_utils.py b/tests/test_utils.py index 37dfb9d..634f28f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,14 @@ +import re from io import StringIO from contextlib import closing from trace_commentor import flags, Commentor +WS = re.compile(" +") + def asserteq_or_print(value, ground_truth): if flags.DEBUG or flags.PRINT: print(value) else: - value = value.strip("\n").rstrip(" ") - ground_truth = ground_truth.strip("\n").rstrip(" ") + value = re.sub(WS, " ", value.strip("\n").rstrip(" ").rstrip("\n")) + ground_truth = re.sub(WS, " ", ground_truth.strip("\n").rstrip(" ").rstrip("\n")) assert value == ground_truth, "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth, ">>>>>>>> GROUND\n"]) diff --git a/trace_commentor/__init__.py b/trace_commentor/__init__.py index 421fb93..0b66588 100644 --- a/trace_commentor/__init__.py +++ b/trace_commentor/__init__.py @@ -1 +1,2 @@ from .commentor import Commentor +from .formatters import silent diff --git a/trace_commentor/commentor.py b/trace_commentor/commentor.py index 497d434..fe63b94 100644 --- a/trace_commentor/commentor.py +++ b/trace_commentor/commentor.py @@ -1,5 +1,6 @@ import ast import inspect +import re from inspect import getfullargspec from functools import wraps @@ -10,6 +11,9 @@ from . import flags from .utils import sign, to_source, comment_to_file +NEWLINE = re.compile(" *\n *") + + class Commentor(object): def __init__(self, output="", fmt=[], check=True, _exit=True) -> None: @@ -29,6 +33,12 @@ class Commentor(object): def __call__(self, func): raw_lines, start_lineno = inspect.getsourcelines(func) + with open(inspect.getfile(func)) as f: + all_lines = f.readlines() + other_lines = dict( + before="".join(all_lines[:start_lineno-1]), + after="".join(all_lines[start_lineno+len(raw_lines):]) + ) self.indent = len(raw_lines[0]) - len(raw_lines[0].lstrip()) unindented_source = ''.join([l[self.indent:] for l in raw_lines]) self.root = ast.parse(unindented_source).body[0] @@ -71,8 +81,8 @@ class Commentor(object): self.process(self.root) # output { - comments = "\n".join(self._lines) - if comment_to_file(comments, file=self.file): + comments = "\n".join([l for l in self._lines if l is not None]) + if comment_to_file(comments, file=self.file, **other_lines): if self._exit: exit(0) return self._return @@ -88,13 +98,16 @@ class Commentor(object): if handler is None: raise NotImplementedError(f"Unknown how to handle {node_type} node.") return handler(node, self, *args, **kwargs) + + def to_source(self, node): + return to_source(node, self.indent) def eval(self, node: ast.Expr, format=True): - src = node if type(node) == str else to_source(node) + src = node if type(node) == str else self.to_source(node) try: obj = eval(src, self._globals, self._locals) except Exception as e: - e.add_note(f"\tduring evaluating `{src}`") + e.add_note(f"\tduring evaluating `{src}` translated from `{ast.dump(node)}`") raise e if not format: return obj @@ -102,16 +115,26 @@ class Commentor(object): fmt = self.get_formatter(obj) fmt_obj = fmt(obj) if fmt_obj is not None: - return f"{fmt(obj)} : {src}" + fmt_obj = str(fmt_obj).replace("\n", " ") + if len(fmt_obj) > flags.MAX_FMT_LEN: + fmt_obj = fmt_obj[:flags.MAX_FMT_LEN - 5 - 10] + " ... " + fmt_obj[-10:] + else: + fmt_obj = fmt_obj + src = re.sub(NEWLINE, " ", src) + if len(src) > flags.MAX_EXPR_LEN: + src = src[:flags.MAX_EXPR_LEN//2 - 2] + " ... " + src[-flags.MAX_EXPR_LEN//2+3:] + return f"{fmt_obj} : {src}" def exec(self, node: ast.stmt): - src = to_source(node) + src = self.to_source(node) exec(src, self._globals, self._locals) def get_formatter(self, obj): for typ, fmt in self._formatters: if isinstance(typ, type) and isinstance(obj, typ): return fmt + elif isinstance(typ, (list, tuple)) and len(typ) and isinstance(typ[0], type) and isinstance(obj, typ): + return fmt elif not isinstance(typ, type) and callable(typ) and typ(obj): return fmt else: @@ -127,7 +150,7 @@ class Commentor(object): def append_source(self, line=None): if self.state == flags.COMMENT: - self.__append('"""') + self.__append('"""\n') self.state = flags.SOURCE if line is not None: return self.__append(sign(line, 2)) diff --git a/trace_commentor/flags.py b/trace_commentor/flags.py index ee42a8f..ec5837d 100644 --- a/trace_commentor/flags.py +++ b/trace_commentor/flags.py @@ -3,6 +3,8 @@ import os bool_env = lambda name: os.environ.get(name, "false").lower() in ('true', '1', 'yes') +MAX_EXPR_LEN = 37 +MAX_FMT_LEN = 45 DEBUG = bool_env("DEBUG") PRINT = bool_env("PRINT") INDENT = 4 diff --git a/trace_commentor/formatters/__init__.py b/trace_commentor/formatters/__init__.py index 7661074..2fe6c06 100644 --- a/trace_commentor/formatters/__init__.py +++ b/trace_commentor/formatters/__init__.py @@ -1,17 +1,18 @@ import types +from .desc import desc + def silent(_): return None + def Tensor(tensor): return "Tensor" in tensor.__class__.__name__ -def fmt_Tensor(tensor): - return f"{list(tensor.shape)}" - LIST = [ (callable, silent), - (Tensor, fmt_Tensor), + (Tensor, desc), (types.ModuleType, silent), + ((list, dict), desc), ] diff --git a/trace_commentor/formatters/desc.py b/trace_commentor/formatters/desc.py new file mode 100644 index 0000000..374b95c --- /dev/null +++ b/trace_commentor/formatters/desc.py @@ -0,0 +1,46 @@ +def _apply(x, f): + + if isinstance(x, (list)): + return x.__class__([_apply(xi, f) for xi in x]) + + if not isinstance(x, dict): + if _has_method(x, "to_dict"): + x = x.to_dict() + if _has_method(x, "as_dict"): + x = x.as_dict() + if _has_method(x, "items"): + x = dict(x) + + if isinstance(x, dict): + return {k: _apply(v, f) for k, v in x.items()} + + try: + return f(x) + except TypeError: + return _print(f"") + except Exception as e: + return _print(f"<{e}>") + + +def desc(x): + def _desc(_x): + if "Tensor" in _x.__class__.__name__: + dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i") + return _print(f"Tensor({tuple(_x.shape)}, {dtype})") + if isinstance(_x, (int, float, str)): + return _print(_x) + if _x is None: + return "None" + raise TypeError + + return str(_apply(x, _desc)) + + +class _print(str): + + def __repr__(self) -> str: + return self + + +def _has_method(obj, methodname) -> bool: + return getattr(getattr(obj, methodname, None), "__call__", False) is not False diff --git a/trace_commentor/handlers/__init__.py b/trace_commentor/handlers/__init__.py index 7b97318..e5acb47 100644 --- a/trace_commentor/handlers/__init__.py +++ b/trace_commentor/handlers/__init__.py @@ -1,6 +1,6 @@ -from .definitions import FunctionDef, Return -from .statements import Pass, Assign -from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword -from .literals import Constant, Tuple, List +from .definitions import FunctionDef, Return, Lambda +from .statements import Pass, Assign, AnnAssign +from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp +from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr from .variables import Name from .control_flow import If, For, Continue, Break diff --git a/trace_commentor/handlers/control_flow.py b/trace_commentor/handlers/control_flow.py index 696716c..ce36e88 100644 --- a/trace_commentor/handlers/control_flow.py +++ b/trace_commentor/handlers/control_flow.py @@ -1,6 +1,5 @@ import ast from .. import flags -from ..utils import to_source ELIF = 2 PASS = 4 @@ -18,20 +17,20 @@ def If(self, cmtor, state=0): state = state | PASS if state & ELIF: - cmtor.append_source(f"elif {to_source(self.test)}: # {test_comment}") + cmtor.append_source(f"elif {cmtor.to_source(self.test)}: # {test_comment}") else: - cmtor.append_source(f"if {to_source(self.test)}: # {test_comment}") + cmtor.append_source(f"if {cmtor.to_source(self.test)}: # {test_comment}") cmtor.indent += flags.INDENT for stmt in self.body: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: - cmtor.append_source(to_source(stmt)) + cmtor.append_source(cmtor.to_source(stmt)) if test: cmtor.process(stmt) test = cmtor._stack_event == flags.NORMAL else: cmtor._lines[-1] += " # skipped" - cmtor.append_source() + cmtor.append_source("") cmtor.indent -= flags.INDENT if self.orelse: @@ -45,16 +44,16 @@ def If(self, cmtor, state=0): cmtor.indent += flags.INDENT for stmt in self.orelse: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: - cmtor.append_source(to_source(stmt)) + cmtor.append_source(cmtor.to_source(stmt)) if test: cmtor.process(stmt) test = cmtor._stack_event == flags.NORMAL - cmtor.append_source() + cmtor.append_source("") cmtor.indent -= flags.INDENT def For(self, cmtor): - cmtor.append_source(to_source(ast.For(self.target, self.iter, [], []))) + cmtor.append_source(cmtor.to_source(ast.For(self.target, self.iter, [], []))) loop_start: int = cmtor.next_line() @@ -68,15 +67,15 @@ def For(self, cmtor): # enter new iteration (mantain locals()) cmtor.append_source("") - cmtor.append_source("###### !new iteration! ######") + # cmtor.append_source("###### !new iteration! ######") cmtor._locals[REG_it] = it stmt = ast.Assign([self.target], ast.Name(REG_it, ast.Load())) - cmtor.process(stmt) + cmtor.exec(stmt) # process body for stmt in self.body: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: - cmtor.append_source(to_source(stmt)) + cmtor.append_source(cmtor.to_source(stmt)) if cmtor._stack_event == flags.NORMAL: cmtor.process(stmt) else: @@ -92,12 +91,13 @@ def For(self, cmtor): cmtor.indent -= flags.INDENT - # comment out all code except for the last iter + # delete all code except for the last iter for lineno in range(loop_start, last_iter_start): - if cmtor._lines_category[lineno][0] == flags.SOURCE: - line: str = cmtor._lines[lineno] - if line.lstrip() and line.lstrip()[0] != "#": - cmtor._lines[lineno] = " " * self_indent + "# " + line[self_indent:] + cmtor._lines[lineno] = None + # if cmtor._lines_category[lineno][0] == flags.SOURCE: + # line: str = cmtor._lines[lineno] + # if line.lstrip() and line.lstrip()[0] != "#": + # cmtor._lines[lineno] = " " * self_indent + "# " + line[self_indent:] def Break(self, cmtor): diff --git a/trace_commentor/handlers/definitions.py b/trace_commentor/handlers/definitions.py index 88034db..ca2e0f5 100644 --- a/trace_commentor/handlers/definitions.py +++ b/trace_commentor/handlers/definitions.py @@ -1,8 +1,7 @@ from .. import flags -from ..utils import to_source def FunctionDef(self, cmtor): - cmtor.append_source(f"def {self.name}({to_source(self.args)}):") + cmtor.append_source(f"def {self.name}({cmtor.to_source(self.args)}):") cmtor.indent += flags.INDENT if self is cmtor.root: @@ -13,7 +12,7 @@ def FunctionDef(self, cmtor): for stmt in self.body: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: - cmtor.append_source(to_source(stmt)) + cmtor.append_source(cmtor.to_source(stmt)) if self is cmtor.root: cmtor.process(stmt) @@ -26,3 +25,7 @@ def FunctionDef(self, cmtor): def Return(self, cmtor): cmtor.process(self.value) cmtor._return = cmtor.eval(self.value, format=False) + + +def Lambda(self, cmtor): + pass diff --git a/trace_commentor/handlers/expressions.py b/trace_commentor/handlers/expressions.py index 864142f..8abea14 100644 --- a/trace_commentor/handlers/expressions.py +++ b/trace_commentor/handlers/expressions.py @@ -55,3 +55,8 @@ def Slice(self, cmtor): def keyword(self, cmtor): cmtor.process(self.value) + + +def IfExp(self, cmtor): + cmtor.append_comment(cmtor.eval(self.test)) + cmtor.append_comment(cmtor.eval(self)) diff --git a/trace_commentor/handlers/literals.py b/trace_commentor/handlers/literals.py index b94e268..28fa08d 100644 --- a/trace_commentor/handlers/literals.py +++ b/trace_commentor/handlers/literals.py @@ -10,3 +10,21 @@ def Tuple(self, cmtor): def List(self, cmtor): for x in self.elts: cmtor.process(x) + + +def Set(self, cmtor): + for x in self.elts: + cmtor.process(x) + + +def Dict(self, cmtor): + for k, v in zip(self.keys, self.values): + cmtor.process(v) + + +def FormattedValue(self, cmtor): + pass + + +def JoinedStr(self, cmtor): + cmtor.append_comment(cmtor.eval(self)) diff --git a/trace_commentor/handlers/statements.py b/trace_commentor/handlers/statements.py index 34bcc81..164c88d 100644 --- a/trace_commentor/handlers/statements.py +++ b/trace_commentor/handlers/statements.py @@ -6,7 +6,16 @@ def Pass(self, cmtor): def Assign(self, cmtor): cmtor.process(self.value) cmtor.exec(self) - if type(self.value) not in flags.ASSIGN_SILENT: + if type(self.value) not in flags.ASSIGN_SILENT and len(self.targets): cmtor.append_comment(f"----------") for target in self.targets: cmtor.process(target) + + +def AnnAssign(self, cmtor): + if getattr(self, "value", None) is not None: + cmtor.process(self.value) + cmtor.exec(self) + if type(self.value) not in flags.ASSIGN_SILENT: + cmtor.append_comment(f"----------") + cmtor.process(self.target) diff --git a/trace_commentor/utils.py b/trace_commentor/utils.py index 3b62ad7..98b5306 100644 --- a/trace_commentor/utils.py +++ b/trace_commentor/utils.py @@ -19,10 +19,11 @@ def sign(line: str, depth=1): return line -def to_source(node): +def to_source(node, indent=0): src = astor.to_source(node).rstrip("\n") if type(node) != ast.Tuple and len(src) > 2 and src[0] == "(" and src[-1] == ")": src = src[1:-1] + src = src.replace("\n", "\n" + " " * indent) return src @@ -30,7 +31,7 @@ def dump(node, file=sys.stderr): print(ast.dump(node, indent=4), file=file) -def comment_to_file(code, file: str) -> bool: +def comment_to_file(code, file: str, before="", after="") -> bool: if file == "": return False elif file == "": @@ -38,18 +39,18 @@ def comment_to_file(code, file: str) -> bool: syntax = Syntax(code, "python") rich.print(syntax, file=sys.stderr) else: - rich.print(code, file=sys.stderr) + print(before + code + after, file=sys.stderr) elif file == "": if sys.stdout.isatty(): syntax = Syntax(code, "python") rich.print(syntax, file=sys.stdout) else: - rich.print(code, file=sys.stdout) + print(before + code + after, file=sys.stdout) elif isinstance(file, IOBase): - rich.print(code, file=file) + print(code, file=file) elif type(file) == str: with open(file, "wt") as f: - rich.print(code, file=f) + print(before + code + after, file=f) else: raise NotImplementedError(f"Unknown file protocal {file}") return True