This commit is contained in:
Yuyao Huang 2024-04-20 02:01:26 +08:00
parent 245a927382
commit 49651168e3
9 changed files with 136 additions and 16 deletions

View File

@ -1,12 +1,26 @@
from tests.test_utils import *
@Commentor(fmt=[
(type(None), lambda o: None),
])
def function():
x = 2
print(x == 2)
def test():
@Commentor()
def target():
x = 2
if x > 3:
x = 2 * x
y = 1
elif x > 2:
x = 4 * x
y = 2
elif x > 3:
x = 4 * x
y = 3
else:
x = 8 * x
y = 5
print(function())
print(target())
test()

View File

@ -17,3 +17,45 @@ def test_constant():
None : print(x == 2)
"""
''')
def test():
@Commentor("<return>")
def target():
x = 2
if x > 3:
x = 2 * x
y = 1
elif x > 2:
x = 4 * x
y = 2
elif x > 3:
x = 4 * x
y = 3
else:
x = 8 * x
y = 5
asserteq_or_print(target(), '''
def target():
x = 2
if x > 3: # False
x = 2 * x
y = 1
elif x > 2: # False
x = 4 * x
y = 2
elif x > 3: # False
x = 4 * x
y = 3
else: # True
x = 8 * x
"""
2 : x
16 : 8 * x
----------
16 : x
"""
y = 5
''')

View File

@ -23,7 +23,7 @@ def test_tuple():
def target():
a, b = 1, 2
"""
========
----------
1 : a
2 : b
"""

View File

@ -64,12 +64,12 @@ class Commentor(object):
return proxy_func
def process(self, node: ast.AST):
def process(self, node: ast.AST, *args, **kwargs):
node_type = node.__class__.__name__
handler = getattr(handlers, node_type, None)
if handler is None:
raise NotImplementedError(f"Unknown how to handle {node_type} node.")
return handler(node, self)
return handler(node, self, *args, **kwargs)
def eval(self, node: ast.Expr, format=True):
src = to_source(node)

View File

@ -3,3 +3,4 @@ from .statements import Pass, Assign
from .expressions import Expr, BinOp, Call, Compare
from .literals import Constant, Tuple
from .variables import Name
from .control_flow import If

View File

@ -0,0 +1,50 @@
import ast
from .. import flags
from ..utils import to_source, APPEND_SOURCE_BY_THEMSELVES
ELIF = 2
PASS = 4
def If(self, cmtor, state=0):
if state & PASS:
test = False
test_comment = "skipped"
else:
test = cmtor.eval(self.test, format=False)
test_comment = test
if test:
state = state | PASS
if state & ELIF:
cmtor.append_source(f"elif {to_source(self.test)}: # {test_comment}")
else:
cmtor.append_source(f"if {to_source(self.test)}: # {test_comment}")
cmtor.indent += flags.INDENT
for stmt in self.body:
if type(stmt) not in APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt))
if test:
cmtor.process(stmt)
cmtor.append_source()
cmtor.indent -= flags.INDENT
if self.orelse:
if type(self.orelse[0]) == ast.If:
cmtor.process(self.orelse[0], state=state | ELIF)
else:
test = not (state & PASS)
test_comment = True if test else "skipped"
cmtor.append_source(f"else: # {test_comment}")
cmtor.indent += flags.INDENT
for stmt in self.orelse:
if type(stmt) not in APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt))
if test:
cmtor.process(stmt)
cmtor.append_source()
cmtor.indent -= flags.INDENT

View File

@ -1,5 +1,6 @@
import ast
from .. import flags
from ..utils import to_source
from ..utils import to_source, APPEND_SOURCE_BY_THEMSELVES
def FunctionDef(self, cmtor):
cmtor.append_source(f"def {self.name}():")
@ -7,7 +8,8 @@ def FunctionDef(self, cmtor):
for stmt in self.body:
cmtor.append_source(to_source(stmt))
if type(stmt) not in APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt))
if self is cmtor.root:
cmtor.process(stmt)

View File

@ -8,6 +8,6 @@ def Assign(self, cmtor):
cmtor.process(self.value)
cmtor.exec(self)
if type(self.value) not in [ast.Constant]:
cmtor.append_comment("========")
cmtor.append_comment("----------")
for target in self.targets:
cmtor.process(target)

View File

@ -1,8 +1,15 @@
import os
import ast
import astor
import inspect
import os
from . import flags
APPEND_SOURCE_BY_THEMSELVES = [
ast.If,
]
def sign(line: str, depth=1):
if flags.DEBUG:
currentframe = inspect.currentframe()
@ -12,5 +19,9 @@ def sign(line: str, depth=1):
else:
return line
def to_source(node):
return astor.to_source(node).rstrip("\n")
src = astor.to_source(node).rstrip("\n")
if type(node) != ast.Tuple and len(src) > 2 and src[0] == "(" and src[-1] == ")":
src = src[1:-1]
return src