Assign();Name();Call()

This commit is contained in:
Yuyao Huang 2024-04-19 18:08:34 +08:00
parent 5017ce2736
commit b86e2e5613
11 changed files with 87 additions and 51 deletions

Binary file not shown.

View File

@ -3,7 +3,7 @@ from test_utils import *
def test_function_def():
@Commentor()
@Commentor("<return>")
def target():
pass

View File

@ -4,7 +4,7 @@ from test_utils import asserteq_or_print
def test_binop():
@Commentor()
@Commentor("<return>")
def target():
1 + 1
@ -20,7 +20,7 @@ def test_binop():
def test_binop_cascade():
@Commentor()
@Commentor("<return>")
def target():
1 + 1 + 1
@ -33,3 +33,19 @@ def test_binop_cascade():
3 : 1 + 1 + 1
"""
''')
def test_call_print():
@Commentor("<return>")
def target():
print("This line will be printed.")
asserteq_or_print(
target(), '''
def target():
print('This line will be printed.')
"""
None : print('This line will be printed.')
"""
''')

View File

@ -3,12 +3,11 @@ from test_utils import *
def test_constant():
@Commentor()
@Commentor("<return>")
def target():
1
asserteq_or_print(target(),
'''
asserteq_or_print(target(), '''
def target():
1
''')

View File

@ -1,37 +0,0 @@
import torch
import torch.nn as nn
from trace_commentor.parser import *
def test_func_no_arg():
@analyse
def target():
x = torch.ones(4, 5)
for i in range(3):
x = x[..., None, :]
a = torch.randn(309, 110, 3)[:100]
f = nn.Linear(3, 128)
b = f(a.reshape(-1, 3)).reshape(309, 110, 128)
c = torch.concat((a, b), dim=-1)
return c.flatten()
print()
target()
def test_for_loop():
@analyse
def target():
a = 1
for i in range(3):
a += 1
print(a)
print()
target()

20
tests/test_variables.py Normal file
View File

@ -0,0 +1,20 @@
from test_utils import *
def test_assign():
@Commentor("<return>")
def target():
myint = 7
print(myint)
asserteq_or_print(
target(), '''
def target():
myint = 7
print(myint)
"""
7 : myint
None : print(myint)
"""
''')

View File

@ -1,5 +1,7 @@
import inspect
import ast
import inspect
import sys
import rich
from inspect import getfullargspec
from functools import wraps
@ -12,13 +14,15 @@ from .utils import sign, to_source
class Commentor(object):
def __init__(self, _formatters=[]) -> None:
def __init__(self, output="<stderr>", fmt=[]) -> None:
self._locals = dict()
self._globals = dict()
self._formatters = _formatters + formatters.LIST
self._return = None
self._formatters = fmt + formatters.LIST
self._lines = []
self.indent = 0
self.state = flags.SOURCE
self.file = output
def __call__(self, func):
@ -28,14 +32,30 @@ class Commentor(object):
self.root = ast.parse(unindented_source).body[0]
if flags.DEBUG:
with open("test.log", "wt") as f:
with open("debug.log", "wt") as f:
print(ast.dump(self.root, indent=4), file=f)
@wraps(func)
def proxy_func(*args, **kwargs):
# input {
self._locals = kwargs
# }
self.process(self.root)
return "\n".join(self._lines)
# output {
code = "\n".join(self._lines)
if self.file == "<return>":
return code
elif self.file == "<stderr>":
rich.print(code, file=sys.stderr)
elif self.file == "<stdout>":
rich.print(code, file=sys.stdout)
else:
with open(self.file, "wt") as f:
rich.print(code, file=f)
return self._return
# }
return proxy_func
@ -52,6 +72,10 @@ class Commentor(object):
fmt = self.get_formatter(obj)
return f"{fmt(obj)} : {src}"
def exec(self, node: ast.stmt):
src = to_source(node)
exec(src, self._globals, self._locals)
def get_formatter(self, obj):
for typ, fmt in self._formatters:
if isinstance(obj, typ):

View File

@ -1,4 +1,5 @@
from .definitions import FunctionDef
from .statements import Pass
from .expressions import Expr, BinOp
from .statements import Pass, Assign
from .expressions import Expr, BinOp, Call
from .literals import Constant
from .variables import Name

View File

@ -5,3 +5,10 @@ def BinOp(self, cmtor):
cmtor.process(self.left)
cmtor.process(self.right)
cmtor.append_comment(cmtor.eval(self))
def Call(self, cmtor):
for arg in self.args:
cmtor.process(arg)
for kwarg in self.keywords:
cmtor.process(kwarg)
cmtor.append_comment(cmtor.eval(self))

View File

@ -1,2 +1,6 @@
def Pass(self, cmtor):
pass
def Assign(self, cmtor):
cmtor.process(self.value)
cmtor.exec(self)

View File

@ -0,0 +1,2 @@
def Name(self, cmtor):
cmtor.append_comment(cmtor.eval(self))