Assign();Name();Call()

This commit is contained in:
Yuyao Huang 2024-04-19 18:08:34 +08:00
parent 5017ce2736
commit b86e2e5613
11 changed files with 87 additions and 51 deletions

Binary file not shown.

View File

@ -3,7 +3,7 @@ from test_utils import *
def test_function_def(): def test_function_def():
@Commentor() @Commentor("<return>")
def target(): def target():
pass pass

View File

@ -4,7 +4,7 @@ from test_utils import asserteq_or_print
def test_binop(): def test_binop():
@Commentor() @Commentor("<return>")
def target(): def target():
1 + 1 1 + 1
@ -20,7 +20,7 @@ def test_binop():
def test_binop_cascade(): def test_binop_cascade():
@Commentor() @Commentor("<return>")
def target(): def target():
1 + 1 + 1 1 + 1 + 1
@ -33,3 +33,19 @@ def test_binop_cascade():
3 : 1 + 1 + 1 3 : 1 + 1 + 1
""" """
''') ''')
def test_call_print():
@Commentor("<return>")
def target():
print("This line will be printed.")
asserteq_or_print(
target(), '''
def target():
print('This line will be printed.')
"""
None : print('This line will be printed.')
"""
''')

View File

@ -3,12 +3,11 @@ from test_utils import *
def test_constant(): def test_constant():
@Commentor() @Commentor("<return>")
def target(): def target():
1 1
asserteq_or_print(target(), asserteq_or_print(target(), '''
'''
def target(): def target():
1 1
''') ''')

View File

@ -1,37 +0,0 @@
import torch
import torch.nn as nn
from trace_commentor.parser import *
def test_func_no_arg():
@analyse
def target():
x = torch.ones(4, 5)
for i in range(3):
x = x[..., None, :]
a = torch.randn(309, 110, 3)[:100]
f = nn.Linear(3, 128)
b = f(a.reshape(-1, 3)).reshape(309, 110, 128)
c = torch.concat((a, b), dim=-1)
return c.flatten()
print()
target()
def test_for_loop():
@analyse
def target():
a = 1
for i in range(3):
a += 1
print(a)
print()
target()

20
tests/test_variables.py Normal file
View File

@ -0,0 +1,20 @@
from test_utils import *
def test_assign():
@Commentor("<return>")
def target():
myint = 7
print(myint)
asserteq_or_print(
target(), '''
def target():
myint = 7
print(myint)
"""
7 : myint
None : print(myint)
"""
''')

View File

@ -1,5 +1,7 @@
import inspect
import ast import ast
import inspect
import sys
import rich
from inspect import getfullargspec from inspect import getfullargspec
from functools import wraps from functools import wraps
@ -12,13 +14,15 @@ from .utils import sign, to_source
class Commentor(object): class Commentor(object):
def __init__(self, _formatters=[]) -> None: def __init__(self, output="<stderr>", fmt=[]) -> None:
self._locals = dict() self._locals = dict()
self._globals = dict() self._globals = dict()
self._formatters = _formatters + formatters.LIST self._return = None
self._formatters = fmt + formatters.LIST
self._lines = [] self._lines = []
self.indent = 0 self.indent = 0
self.state = flags.SOURCE self.state = flags.SOURCE
self.file = output
def __call__(self, func): def __call__(self, func):
@ -28,14 +32,30 @@ class Commentor(object):
self.root = ast.parse(unindented_source).body[0] self.root = ast.parse(unindented_source).body[0]
if flags.DEBUG: if flags.DEBUG:
with open("test.log", "wt") as f: with open("debug.log", "wt") as f:
print(ast.dump(self.root, indent=4), file=f) print(ast.dump(self.root, indent=4), file=f)
@wraps(func) @wraps(func)
def proxy_func(*args, **kwargs): def proxy_func(*args, **kwargs):
# input {
self._locals = kwargs self._locals = kwargs
# }
self.process(self.root) self.process(self.root)
return "\n".join(self._lines)
# output {
code = "\n".join(self._lines)
if self.file == "<return>":
return code
elif self.file == "<stderr>":
rich.print(code, file=sys.stderr)
elif self.file == "<stdout>":
rich.print(code, file=sys.stdout)
else:
with open(self.file, "wt") as f:
rich.print(code, file=f)
return self._return
# }
return proxy_func return proxy_func
@ -51,6 +71,10 @@ class Commentor(object):
obj = eval(src, self._globals, self._locals) obj = eval(src, self._globals, self._locals)
fmt = self.get_formatter(obj) fmt = self.get_formatter(obj)
return f"{fmt(obj)} : {src}" return f"{fmt(obj)} : {src}"
def exec(self, node: ast.stmt):
src = to_source(node)
exec(src, self._globals, self._locals)
def get_formatter(self, obj): def get_formatter(self, obj):
for typ, fmt in self._formatters: for typ, fmt in self._formatters:

View File

@ -1,4 +1,5 @@
from .definitions import FunctionDef from .definitions import FunctionDef
from .statements import Pass from .statements import Pass, Assign
from .expressions import Expr, BinOp from .expressions import Expr, BinOp, Call
from .literals import Constant from .literals import Constant
from .variables import Name

View File

@ -5,3 +5,10 @@ def BinOp(self, cmtor):
cmtor.process(self.left) cmtor.process(self.left)
cmtor.process(self.right) cmtor.process(self.right)
cmtor.append_comment(cmtor.eval(self)) cmtor.append_comment(cmtor.eval(self))
def Call(self, cmtor):
for arg in self.args:
cmtor.process(arg)
for kwarg in self.keywords:
cmtor.process(kwarg)
cmtor.append_comment(cmtor.eval(self))

View File

@ -1,2 +1,6 @@
def Pass(self, cmtor): def Pass(self, cmtor):
pass pass
def Assign(self, cmtor):
cmtor.process(self.value)
cmtor.exec(self)

View File

@ -0,0 +1,2 @@
def Name(self, cmtor):
cmtor.append_comment(cmtor.eval(self))