Lambda(), AnnAssign(), IfExp(), Set(), Dict(), FormattedValue(), JoinedStr(); some beautification.

This commit is contained in:
Yuyao Huang 2024-04-23 20:46:52 +08:00
parent 1b4eb16d51
commit f448993c0d
16 changed files with 192 additions and 214 deletions

View File

@ -43,12 +43,15 @@ def test_if():
if x > 3: # False
x = 2 * x # skipped
y = 1 # skipped
elif x > 2: # False
x = 4 * x # skipped
y = 2 # skipped
elif x > 3: # False
x = 4 * x # skipped
y = 3 # skipped
else: # True
x = 8 * x
"""
@ -57,7 +60,9 @@ def test_if():
----------
16 : x
"""
y = 5
''')
@ -81,133 +86,20 @@ def test_for():
odds = []
for x in range(10):
###### !new iteration! ######
"""
0 : __REG__for_loop_iter_once
----------
0 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
1 : __REG__for_loop_iter_once
----------
1 : x
"""
# if x % 2 == 0: # False
# continue # skipped
# odds.append(x)
"""
[] : odds
1 : x
None : odds.append(x)
"""
###### !new iteration! ######
"""
2 : __REG__for_loop_iter_once
----------
2 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
3 : __REG__for_loop_iter_once
----------
3 : x
"""
# if x % 2 == 0: # False
# continue # skipped
# odds.append(x)
"""
[1] : odds
3 : x
None : odds.append(x)
"""
###### !new iteration! ######
"""
4 : __REG__for_loop_iter_once
----------
4 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
5 : __REG__for_loop_iter_once
----------
5 : x
"""
# if x % 2 == 0: # False
# continue # skipped
# odds.append(x)
"""
[1, 3] : odds
5 : x
None : odds.append(x)
"""
###### !new iteration! ######
"""
6 : __REG__for_loop_iter_once
----------
6 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
7 : __REG__for_loop_iter_once
----------
7 : x
"""
# if x % 2 == 0: # False
# continue # skipped
# odds.append(x)
"""
[1, 3, 5] : odds
7 : x
None : odds.append(x)
"""
###### !new iteration! ######
"""
8 : __REG__for_loop_iter_once
----------
8 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
9 : __REG__for_loop_iter_once
----------
9 : x
"""
if x % 2 == 0: # False
continue # skipped
odds.append(x)
"""
[1, 3, 5, 7] : odds
9 : x
None : odds.append(x)
"""
return odds
"""
[1, 3, 5, 7, 9] : odds
"""''')
"""
''')

View File

@ -53,6 +53,7 @@ def test_args():
5 : c
2 : k
"""
return a + k
"""
1 : a

View File

@ -21,87 +21,60 @@ def test_torch():
return c.flatten()
asserteq_or_print(
target(), ''' def target():
target(), '''
def target():
x = torch.ones(4, 5)
"""
[4, 5] : torch.ones(4, 5)
Tensor((4, 5), f32) : torch.ones(4, 5)
----------
[4, 5] : x
Tensor((4, 5), f32) : x
"""
for i in range(3):
###### !new iteration! ######
"""
0 : __REG__for_loop_iter_once
----------
0 : i
"""
# x = x[..., None, :]
"""
[4, 5] : x
[4, 1, 5] : x[..., None, :]
----------
[4, 1, 5] : x
"""
###### !new iteration! ######
"""
1 : __REG__for_loop_iter_once
----------
1 : i
"""
# x = x[..., None, :]
"""
[4, 1, 5] : x
[4, 1, 1, 5] : x[..., None, :]
----------
[4, 1, 1, 5] : x
"""
###### !new iteration! ######
"""
2 : __REG__for_loop_iter_once
----------
2 : i
"""
x = x[..., None, :]
"""
[4, 1, 1, 5] : x
[4, 1, 1, 1, 5] : x[..., None, :]
Tensor((4, 1, 1, 5), f32) : x
Tensor((4, 1, 1, 1, 5), f32) : x[..., None, :]
----------
[4, 1, 1, 1, 5] : x
Tensor((4, 1, 1, 1, 5), f32) : x
"""
a = torch.randn(309, 110, 3)[:100]
"""
[309, 110, 3] : torch.randn(309, 110, 3)
[100, 110, 3] : torch.randn(309, 110, 3)[:100]
Tensor((309, 110, 3), f32) : torch.randn(309, 110, 3)
Tensor((100, 110, 3), f32) : torch.randn(309, 110, 3)[:100]
----------
[100, 110, 3] : a
Tensor((100, 110, 3), f32) : a
"""
f = nn.Linear(3, 128)
"""
----------
"""
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
"""
[100, 110, 3] : a
[11000, 3] : a.reshape(-1, 3)
[11000, 128] : f(a.reshape(-1, 3))
[100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128)
Tensor((100, 110, 3), f32) : a
Tensor((11000, 3), f32) : a.reshape(-1, 3)
Tensor((11000, 128), f32) : f(a.reshape(-1, 3))
Tensor((100, 110, 128), f32) : f(a.reshape(-1, ... pe(-1, 110, 128)
----------
[100, 110, 128] : b
Tensor((100, 110, 128), f32) : b
"""
c = torch.concat((a, b), dim=-1)
"""
[100, 110, 3] : a
[100, 110, 128] : b
[100, 110, 131] : torch.concat((a, b), dim=-1)
Tensor((100, 110, 3), f32) : a
Tensor((100, 110, 128), f32) : b
Tensor((100, 110, 131), f32) : torch.concat((a, b), dim=-1)
----------
[100, 110, 131] : c
Tensor((100, 110, 131), f32) : c
"""
return c.flatten()
"""
[100, 110, 131] : c
[1441000] : c.flatten()
Tensor((100, 110, 131), f32) : c
Tensor((1441000,), f32) : c.flatten()
"""
''')

View File

@ -1,11 +1,14 @@
import re
from io import StringIO
from contextlib import closing
from trace_commentor import flags, Commentor
WS = re.compile(" +")
def asserteq_or_print(value, ground_truth):
if flags.DEBUG or flags.PRINT:
print(value)
else:
value = value.strip("\n").rstrip(" ")
ground_truth = ground_truth.strip("\n").rstrip(" ")
value = re.sub(WS, " ", value.strip("\n").rstrip(" ").rstrip("\n"))
ground_truth = re.sub(WS, " ", ground_truth.strip("\n").rstrip(" ").rstrip("\n"))
assert value == ground_truth, "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth, ">>>>>>>> GROUND\n"])

View File

@ -1 +1,2 @@
from .commentor import Commentor
from .formatters import silent

View File

@ -1,5 +1,6 @@
import ast
import inspect
import re
from inspect import getfullargspec
from functools import wraps
@ -10,6 +11,9 @@ from . import flags
from .utils import sign, to_source, comment_to_file
NEWLINE = re.compile(" *\n *")
class Commentor(object):
def __init__(self, output="<stderr>", fmt=[], check=True, _exit=True) -> None:
@ -29,6 +33,12 @@ class Commentor(object):
def __call__(self, func):
raw_lines, start_lineno = inspect.getsourcelines(func)
with open(inspect.getfile(func)) as f:
all_lines = f.readlines()
other_lines = dict(
before="".join(all_lines[:start_lineno-1]),
after="".join(all_lines[start_lineno+len(raw_lines):])
)
self.indent = len(raw_lines[0]) - len(raw_lines[0].lstrip())
unindented_source = ''.join([l[self.indent:] for l in raw_lines])
self.root = ast.parse(unindented_source).body[0]
@ -71,8 +81,8 @@ class Commentor(object):
self.process(self.root)
# output {
comments = "\n".join(self._lines)
if comment_to_file(comments, file=self.file):
comments = "\n".join([l for l in self._lines if l is not None])
if comment_to_file(comments, file=self.file, **other_lines):
if self._exit:
exit(0)
return self._return
@ -89,12 +99,15 @@ class Commentor(object):
raise NotImplementedError(f"Unknown how to handle {node_type} node.")
return handler(node, self, *args, **kwargs)
def to_source(self, node):
return to_source(node, self.indent)
def eval(self, node: ast.Expr, format=True):
src = node if type(node) == str else to_source(node)
src = node if type(node) == str else self.to_source(node)
try:
obj = eval(src, self._globals, self._locals)
except Exception as e:
e.add_note(f"\tduring evaluating `{src}`")
e.add_note(f"\tduring evaluating `{src}` translated from `{ast.dump(node)}`")
raise e
if not format:
return obj
@ -102,16 +115,26 @@ class Commentor(object):
fmt = self.get_formatter(obj)
fmt_obj = fmt(obj)
if fmt_obj is not None:
return f"{fmt(obj)} : {src}"
fmt_obj = str(fmt_obj).replace("\n", " ")
if len(fmt_obj) > flags.MAX_FMT_LEN:
fmt_obj = fmt_obj[:flags.MAX_FMT_LEN - 5 - 10] + " ... " + fmt_obj[-10:]
else:
fmt_obj = fmt_obj
src = re.sub(NEWLINE, " ", src)
if len(src) > flags.MAX_EXPR_LEN:
src = src[:flags.MAX_EXPR_LEN//2 - 2] + " ... " + src[-flags.MAX_EXPR_LEN//2+3:]
return f"{fmt_obj} : {src}"
def exec(self, node: ast.stmt):
src = to_source(node)
src = self.to_source(node)
exec(src, self._globals, self._locals)
def get_formatter(self, obj):
for typ, fmt in self._formatters:
if isinstance(typ, type) and isinstance(obj, typ):
return fmt
elif isinstance(typ, (list, tuple)) and len(typ) and isinstance(typ[0], type) and isinstance(obj, typ):
return fmt
elif not isinstance(typ, type) and callable(typ) and typ(obj):
return fmt
else:
@ -127,7 +150,7 @@ class Commentor(object):
def append_source(self, line=None):
if self.state == flags.COMMENT:
self.__append('"""')
self.__append('"""\n')
self.state = flags.SOURCE
if line is not None:
return self.__append(sign(line, 2))

View File

@ -3,6 +3,8 @@ import os
bool_env = lambda name: os.environ.get(name, "false").lower() in ('true', '1', 'yes')
MAX_EXPR_LEN = 37
MAX_FMT_LEN = 45
DEBUG = bool_env("DEBUG")
PRINT = bool_env("PRINT")
INDENT = 4

View File

@ -1,17 +1,18 @@
import types
from .desc import desc
def silent(_):
return None
def Tensor(tensor):
return "Tensor" in tensor.__class__.__name__
def fmt_Tensor(tensor):
return f"{list(tensor.shape)}"
LIST = [
(callable, silent),
(Tensor, fmt_Tensor),
(Tensor, desc),
(types.ModuleType, silent),
((list, dict), desc),
]

View File

@ -0,0 +1,46 @@
def _apply(x, f):
if isinstance(x, (list)):
return x.__class__([_apply(xi, f) for xi in x])
if not isinstance(x, dict):
if _has_method(x, "to_dict"):
x = x.to_dict()
if _has_method(x, "as_dict"):
x = x.as_dict()
if _has_method(x, "items"):
x = dict(x)
if isinstance(x, dict):
return {k: _apply(v, f) for k, v in x.items()}
try:
return f(x)
except TypeError:
return _print(f"<Unknown Type: {x.__class__.__name__}>")
except Exception as e:
return _print(f"<{e}>")
def desc(x):
def _desc(_x):
if "Tensor" in _x.__class__.__name__:
dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
return _print(f"Tensor({tuple(_x.shape)}, {dtype})")
if isinstance(_x, (int, float, str)):
return _print(_x)
if _x is None:
return "None"
raise TypeError
return str(_apply(x, _desc))
class _print(str):
def __repr__(self) -> str:
return self
def _has_method(obj, methodname) -> bool:
return getattr(getattr(obj, methodname, None), "__call__", False) is not False

View File

@ -1,6 +1,6 @@
from .definitions import FunctionDef, Return
from .statements import Pass, Assign
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword
from .literals import Constant, Tuple, List
from .definitions import FunctionDef, Return, Lambda
from .statements import Pass, Assign, AnnAssign
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp
from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr
from .variables import Name
from .control_flow import If, For, Continue, Break

View File

@ -1,6 +1,5 @@
import ast
from .. import flags
from ..utils import to_source
ELIF = 2
PASS = 4
@ -18,20 +17,20 @@ def If(self, cmtor, state=0):
state = state | PASS
if state & ELIF:
cmtor.append_source(f"elif {to_source(self.test)}: # {test_comment}")
cmtor.append_source(f"elif {cmtor.to_source(self.test)}: # {test_comment}")
else:
cmtor.append_source(f"if {to_source(self.test)}: # {test_comment}")
cmtor.append_source(f"if {cmtor.to_source(self.test)}: # {test_comment}")
cmtor.indent += flags.INDENT
for stmt in self.body:
if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt))
cmtor.append_source(cmtor.to_source(stmt))
if test:
cmtor.process(stmt)
test = cmtor._stack_event == flags.NORMAL
else:
cmtor._lines[-1] += " # skipped"
cmtor.append_source()
cmtor.append_source("")
cmtor.indent -= flags.INDENT
if self.orelse:
@ -45,16 +44,16 @@ def If(self, cmtor, state=0):
cmtor.indent += flags.INDENT
for stmt in self.orelse:
if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt))
cmtor.append_source(cmtor.to_source(stmt))
if test:
cmtor.process(stmt)
test = cmtor._stack_event == flags.NORMAL
cmtor.append_source()
cmtor.append_source("")
cmtor.indent -= flags.INDENT
def For(self, cmtor):
cmtor.append_source(to_source(ast.For(self.target, self.iter, [], [])))
cmtor.append_source(cmtor.to_source(ast.For(self.target, self.iter, [], [])))
loop_start: int = cmtor.next_line()
@ -68,15 +67,15 @@ def For(self, cmtor):
# enter new iteration (mantain locals())
cmtor.append_source("")
cmtor.append_source("###### !new iteration! ######")
# cmtor.append_source("###### !new iteration! ######")
cmtor._locals[REG_it] = it
stmt = ast.Assign([self.target], ast.Name(REG_it, ast.Load()))
cmtor.process(stmt)
cmtor.exec(stmt)
# process body
for stmt in self.body:
if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt))
cmtor.append_source(cmtor.to_source(stmt))
if cmtor._stack_event == flags.NORMAL:
cmtor.process(stmt)
else:
@ -92,12 +91,13 @@ def For(self, cmtor):
cmtor.indent -= flags.INDENT
# comment out all code except for the last iter
# delete all code except for the last iter
for lineno in range(loop_start, last_iter_start):
if cmtor._lines_category[lineno][0] == flags.SOURCE:
line: str = cmtor._lines[lineno]
if line.lstrip() and line.lstrip()[0] != "#":
cmtor._lines[lineno] = " " * self_indent + "# " + line[self_indent:]
cmtor._lines[lineno] = None
# if cmtor._lines_category[lineno][0] == flags.SOURCE:
# line: str = cmtor._lines[lineno]
# if line.lstrip() and line.lstrip()[0] != "#":
# cmtor._lines[lineno] = " " * self_indent + "# " + line[self_indent:]
def Break(self, cmtor):

View File

@ -1,8 +1,7 @@
from .. import flags
from ..utils import to_source
def FunctionDef(self, cmtor):
cmtor.append_source(f"def {self.name}({to_source(self.args)}):")
cmtor.append_source(f"def {self.name}({cmtor.to_source(self.args)}):")
cmtor.indent += flags.INDENT
if self is cmtor.root:
@ -13,7 +12,7 @@ def FunctionDef(self, cmtor):
for stmt in self.body:
if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt))
cmtor.append_source(cmtor.to_source(stmt))
if self is cmtor.root:
cmtor.process(stmt)
@ -26,3 +25,7 @@ def FunctionDef(self, cmtor):
def Return(self, cmtor):
cmtor.process(self.value)
cmtor._return = cmtor.eval(self.value, format=False)
def Lambda(self, cmtor):
pass

View File

@ -55,3 +55,8 @@ def Slice(self, cmtor):
def keyword(self, cmtor):
cmtor.process(self.value)
def IfExp(self, cmtor):
cmtor.append_comment(cmtor.eval(self.test))
cmtor.append_comment(cmtor.eval(self))

View File

@ -10,3 +10,21 @@ def Tuple(self, cmtor):
def List(self, cmtor):
for x in self.elts:
cmtor.process(x)
def Set(self, cmtor):
for x in self.elts:
cmtor.process(x)
def Dict(self, cmtor):
for k, v in zip(self.keys, self.values):
cmtor.process(v)
def FormattedValue(self, cmtor):
pass
def JoinedStr(self, cmtor):
cmtor.append_comment(cmtor.eval(self))

View File

@ -6,7 +6,16 @@ def Pass(self, cmtor):
def Assign(self, cmtor):
cmtor.process(self.value)
cmtor.exec(self)
if type(self.value) not in flags.ASSIGN_SILENT:
if type(self.value) not in flags.ASSIGN_SILENT and len(self.targets):
cmtor.append_comment(f"----------")
for target in self.targets:
cmtor.process(target)
def AnnAssign(self, cmtor):
if getattr(self, "value", None) is not None:
cmtor.process(self.value)
cmtor.exec(self)
if type(self.value) not in flags.ASSIGN_SILENT:
cmtor.append_comment(f"----------")
cmtor.process(self.target)

View File

@ -19,10 +19,11 @@ def sign(line: str, depth=1):
return line
def to_source(node):
def to_source(node, indent=0):
src = astor.to_source(node).rstrip("\n")
if type(node) != ast.Tuple and len(src) > 2 and src[0] == "(" and src[-1] == ")":
src = src[1:-1]
src = src.replace("\n", "\n" + " " * indent)
return src
@ -30,7 +31,7 @@ def dump(node, file=sys.stderr):
print(ast.dump(node, indent=4), file=file)
def comment_to_file(code, file: str) -> bool:
def comment_to_file(code, file: str, before="", after="") -> bool:
if file == "<return>":
return False
elif file == "<stderr>":
@ -38,18 +39,18 @@ def comment_to_file(code, file: str) -> bool:
syntax = Syntax(code, "python")
rich.print(syntax, file=sys.stderr)
else:
rich.print(code, file=sys.stderr)
print(before + code + after, file=sys.stderr)
elif file == "<stdout>":
if sys.stdout.isatty():
syntax = Syntax(code, "python")
rich.print(syntax, file=sys.stdout)
else:
rich.print(code, file=sys.stdout)
print(before + code + after, file=sys.stdout)
elif isinstance(file, IOBase):
rich.print(code, file=file)
print(code, file=file)
elif type(file) == str:
with open(file, "wt") as f:
rich.print(code, file=f)
print(before + code + after, file=f)
else:
raise NotImplementedError(f"Unknown file protocal {file}")
return True