This commit is contained in:
Yuyao Huang 2024-04-19 18:51:21 +08:00
parent b86e2e5613
commit 6e8c8e1998
6 changed files with 62 additions and 5 deletions

15
experiment.py Normal file
View File

@ -0,0 +1,15 @@
from tests.test_utils import *
@Commentor(fmt=[
(type(None), lambda o: None),
])
def function():
mystring = 'hello'
print(mystring)
mystring = "hello"
print(mystring)
return 1
print(function())

View File

@ -11,3 +11,26 @@ def test_function_def():
def target():
pass
''')
def test_return():
with closing(StringIO()) as f:
@Commentor(f)
def target():
a = 1
return a + 1
assert target() == 2
asserteq_or_print(
f.getvalue(), '''
def target():
a = 1
return a + 1
"""
1 : a
2 : a + 1
"""
''')

View File

@ -1,7 +1,11 @@
from io import StringIO
from contextlib import closing
from trace_commentor import flags, Commentor
def asserteq_or_print(value, ground_truth):
if flags.DEBUG or flags.PRINT:
print(value)
else:
assert value == ground_truth.strip("\n"), "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth.strip("\n"), ">>>>>>>> GROUND\n"])
value = value.strip("\n").rstrip(" ")
ground_truth = ground_truth.strip("\n").rstrip(" ")
assert value == ground_truth, "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth, ">>>>>>>> GROUND\n"])

View File

@ -5,6 +5,7 @@ import rich
from inspect import getfullargspec
from functools import wraps
from io import IOBase
from . import handlers
from . import formatters
@ -51,9 +52,13 @@ class Commentor(object):
rich.print(code, file=sys.stderr)
elif self.file == "<stdout>":
rich.print(code, file=sys.stdout)
else:
elif isinstance(self.file, IOBase):
rich.print(code, file=self.file)
elif type(self.file) == str:
with open(self.file, "wt") as f:
rich.print(code, file=f)
else:
raise NotImplementedError(f"Unknown file protocal {self.file}")
return self._return
# }
@ -66,10 +71,15 @@ class Commentor(object):
raise NotImplementedError(f"Unknown how to handle {node_type} node.")
return handler(node, self)
def eval(self, node: ast.Expr):
def eval(self, node: ast.Expr, format=True):
src = to_source(node)
obj = eval(src, self._globals, self._locals)
if not format:
return obj
fmt = self.get_formatter(obj)
fmt_obj = fmt(obj)
if fmt_obj is not None:
return f"{fmt(obj)} : {src}"
def exec(self, node: ast.stmt):

View File

@ -1,4 +1,4 @@
from .definitions import FunctionDef
from .definitions import FunctionDef, Return
from .statements import Pass, Assign
from .expressions import Expr, BinOp, Call
from .literals import Constant

View File

@ -15,3 +15,8 @@ def FunctionDef(self, cmtor):
cmtor.append_source()
cmtor.indent -= flags.INDENT
def Return(self, cmtor):
cmtor.process(self.value)
cmtor._return = cmtor.eval(self.value, format=False)