Lambda(), AnnAssign(), IfExp(), Set(), Dict(), FormattedValue(), JoinedStr(); some beautification.

This commit is contained in:
Yuyao Huang 2024-04-23 20:46:52 +08:00
parent 1b4eb16d51
commit f448993c0d
16 changed files with 192 additions and 214 deletions

View File

@ -43,12 +43,15 @@ def test_if():
if x > 3: # False if x > 3: # False
x = 2 * x # skipped x = 2 * x # skipped
y = 1 # skipped y = 1 # skipped
elif x > 2: # False elif x > 2: # False
x = 4 * x # skipped x = 4 * x # skipped
y = 2 # skipped y = 2 # skipped
elif x > 3: # False elif x > 3: # False
x = 4 * x # skipped x = 4 * x # skipped
y = 3 # skipped y = 3 # skipped
else: # True else: # True
x = 8 * x x = 8 * x
""" """
@ -57,7 +60,9 @@ def test_if():
---------- ----------
16 : x 16 : x
""" """
y = 5 y = 5
''') ''')
@ -81,133 +86,20 @@ def test_for():
odds = [] odds = []
for x in range(10): for x in range(10):
###### !new iteration! ######
"""
0 : __REG__for_loop_iter_once
----------
0 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
1 : __REG__for_loop_iter_once
----------
1 : x
"""
# if x % 2 == 0: # False
# continue # skipped
# odds.append(x)
"""
[] : odds
1 : x
None : odds.append(x)
"""
###### !new iteration! ######
"""
2 : __REG__for_loop_iter_once
----------
2 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
3 : __REG__for_loop_iter_once
----------
3 : x
"""
# if x % 2 == 0: # False
# continue # skipped
# odds.append(x)
"""
[1] : odds
3 : x
None : odds.append(x)
"""
###### !new iteration! ######
"""
4 : __REG__for_loop_iter_once
----------
4 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
5 : __REG__for_loop_iter_once
----------
5 : x
"""
# if x % 2 == 0: # False
# continue # skipped
# odds.append(x)
"""
[1, 3] : odds
5 : x
None : odds.append(x)
"""
###### !new iteration! ######
"""
6 : __REG__for_loop_iter_once
----------
6 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
7 : __REG__for_loop_iter_once
----------
7 : x
"""
# if x % 2 == 0: # False
# continue # skipped
# odds.append(x)
"""
[1, 3, 5] : odds
7 : x
None : odds.append(x)
"""
###### !new iteration! ######
"""
8 : __REG__for_loop_iter_once
----------
8 : x
"""
# if x % 2 == 0: # True
# continue # True
# odds.append(x) # skipped
###### !new iteration! ######
"""
9 : __REG__for_loop_iter_once
----------
9 : x
"""
if x % 2 == 0: # False if x % 2 == 0: # False
continue # skipped continue # skipped
odds.append(x) odds.append(x)
""" """
[1, 3, 5, 7] : odds [1, 3, 5, 7] : odds
9 : x 9 : x
None : odds.append(x) None : odds.append(x)
""" """
return odds return odds
""" """
[1, 3, 5, 7, 9] : odds [1, 3, 5, 7, 9] : odds
"""''') """
''')

View File

@ -53,6 +53,7 @@ def test_args():
5 : c 5 : c
2 : k 2 : k
""" """
return a + k return a + k
""" """
1 : a 1 : a

View File

@ -21,87 +21,60 @@ def test_torch():
return c.flatten() return c.flatten()
asserteq_or_print( asserteq_or_print(
target(), ''' def target(): target(), '''
def target():
x = torch.ones(4, 5) x = torch.ones(4, 5)
""" """
[4, 5] : torch.ones(4, 5) Tensor((4, 5), f32) : torch.ones(4, 5)
---------- ----------
[4, 5] : x Tensor((4, 5), f32) : x
""" """
for i in range(3): for i in range(3):
###### !new iteration! ######
"""
0 : __REG__for_loop_iter_once
----------
0 : i
"""
# x = x[..., None, :]
"""
[4, 5] : x
[4, 1, 5] : x[..., None, :]
----------
[4, 1, 5] : x
"""
###### !new iteration! ######
"""
1 : __REG__for_loop_iter_once
----------
1 : i
"""
# x = x[..., None, :]
"""
[4, 1, 5] : x
[4, 1, 1, 5] : x[..., None, :]
----------
[4, 1, 1, 5] : x
"""
###### !new iteration! ######
"""
2 : __REG__for_loop_iter_once
----------
2 : i
"""
x = x[..., None, :] x = x[..., None, :]
""" """
[4, 1, 1, 5] : x Tensor((4, 1, 1, 5), f32) : x
[4, 1, 1, 1, 5] : x[..., None, :] Tensor((4, 1, 1, 1, 5), f32) : x[..., None, :]
---------- ----------
[4, 1, 1, 1, 5] : x Tensor((4, 1, 1, 1, 5), f32) : x
""" """
a = torch.randn(309, 110, 3)[:100] a = torch.randn(309, 110, 3)[:100]
""" """
[309, 110, 3] : torch.randn(309, 110, 3) Tensor((309, 110, 3), f32) : torch.randn(309, 110, 3)
[100, 110, 3] : torch.randn(309, 110, 3)[:100] Tensor((100, 110, 3), f32) : torch.randn(309, 110, 3)[:100]
---------- ----------
[100, 110, 3] : a Tensor((100, 110, 3), f32) : a
""" """
f = nn.Linear(3, 128) f = nn.Linear(3, 128)
""" """
---------- ----------
""" """
b = f(a.reshape(-1, 3)).reshape(-1, 110, 128) b = f(a.reshape(-1, 3)).reshape(-1, 110, 128)
""" """
[100, 110, 3] : a Tensor((100, 110, 3), f32) : a
[11000, 3] : a.reshape(-1, 3) Tensor((11000, 3), f32) : a.reshape(-1, 3)
[11000, 128] : f(a.reshape(-1, 3)) Tensor((11000, 128), f32) : f(a.reshape(-1, 3))
[100, 110, 128] : f(a.reshape(-1, 3)).reshape(-1, 110, 128) Tensor((100, 110, 128), f32) : f(a.reshape(-1, ... pe(-1, 110, 128)
---------- ----------
[100, 110, 128] : b Tensor((100, 110, 128), f32) : b
""" """
c = torch.concat((a, b), dim=-1) c = torch.concat((a, b), dim=-1)
""" """
[100, 110, 3] : a Tensor((100, 110, 3), f32) : a
[100, 110, 128] : b Tensor((100, 110, 128), f32) : b
[100, 110, 131] : torch.concat((a, b), dim=-1) Tensor((100, 110, 131), f32) : torch.concat((a, b), dim=-1)
---------- ----------
[100, 110, 131] : c Tensor((100, 110, 131), f32) : c
""" """
return c.flatten() return c.flatten()
""" """
[100, 110, 131] : c Tensor((100, 110, 131), f32) : c
[1441000] : c.flatten() Tensor((1441000,), f32) : c.flatten()
""" """
''') ''')

View File

@ -1,11 +1,14 @@
import re
from io import StringIO from io import StringIO
from contextlib import closing from contextlib import closing
from trace_commentor import flags, Commentor from trace_commentor import flags, Commentor
WS = re.compile(" +")
def asserteq_or_print(value, ground_truth): def asserteq_or_print(value, ground_truth):
if flags.DEBUG or flags.PRINT: if flags.DEBUG or flags.PRINT:
print(value) print(value)
else: else:
value = value.strip("\n").rstrip(" ") value = re.sub(WS, " ", value.strip("\n").rstrip(" ").rstrip("\n"))
ground_truth = ground_truth.strip("\n").rstrip(" ") ground_truth = re.sub(WS, " ", ground_truth.strip("\n").rstrip(" ").rstrip("\n"))
assert value == ground_truth, "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth, ">>>>>>>> GROUND\n"]) assert value == ground_truth, "\n".join(["\n\n<<<<<<<< VALUE", value, "========================", ground_truth, ">>>>>>>> GROUND\n"])

View File

@ -1 +1,2 @@
from .commentor import Commentor from .commentor import Commentor
from .formatters import silent

View File

@ -1,5 +1,6 @@
import ast import ast
import inspect import inspect
import re
from inspect import getfullargspec from inspect import getfullargspec
from functools import wraps from functools import wraps
@ -10,6 +11,9 @@ from . import flags
from .utils import sign, to_source, comment_to_file from .utils import sign, to_source, comment_to_file
NEWLINE = re.compile(" *\n *")
class Commentor(object): class Commentor(object):
def __init__(self, output="<stderr>", fmt=[], check=True, _exit=True) -> None: def __init__(self, output="<stderr>", fmt=[], check=True, _exit=True) -> None:
@ -29,6 +33,12 @@ class Commentor(object):
def __call__(self, func): def __call__(self, func):
raw_lines, start_lineno = inspect.getsourcelines(func) raw_lines, start_lineno = inspect.getsourcelines(func)
with open(inspect.getfile(func)) as f:
all_lines = f.readlines()
other_lines = dict(
before="".join(all_lines[:start_lineno-1]),
after="".join(all_lines[start_lineno+len(raw_lines):])
)
self.indent = len(raw_lines[0]) - len(raw_lines[0].lstrip()) self.indent = len(raw_lines[0]) - len(raw_lines[0].lstrip())
unindented_source = ''.join([l[self.indent:] for l in raw_lines]) unindented_source = ''.join([l[self.indent:] for l in raw_lines])
self.root = ast.parse(unindented_source).body[0] self.root = ast.parse(unindented_source).body[0]
@ -71,8 +81,8 @@ class Commentor(object):
self.process(self.root) self.process(self.root)
# output { # output {
comments = "\n".join(self._lines) comments = "\n".join([l for l in self._lines if l is not None])
if comment_to_file(comments, file=self.file): if comment_to_file(comments, file=self.file, **other_lines):
if self._exit: if self._exit:
exit(0) exit(0)
return self._return return self._return
@ -88,13 +98,16 @@ class Commentor(object):
if handler is None: if handler is None:
raise NotImplementedError(f"Unknown how to handle {node_type} node.") raise NotImplementedError(f"Unknown how to handle {node_type} node.")
return handler(node, self, *args, **kwargs) return handler(node, self, *args, **kwargs)
def to_source(self, node):
return to_source(node, self.indent)
def eval(self, node: ast.Expr, format=True): def eval(self, node: ast.Expr, format=True):
src = node if type(node) == str else to_source(node) src = node if type(node) == str else self.to_source(node)
try: try:
obj = eval(src, self._globals, self._locals) obj = eval(src, self._globals, self._locals)
except Exception as e: except Exception as e:
e.add_note(f"\tduring evaluating `{src}`") e.add_note(f"\tduring evaluating `{src}` translated from `{ast.dump(node)}`")
raise e raise e
if not format: if not format:
return obj return obj
@ -102,16 +115,26 @@ class Commentor(object):
fmt = self.get_formatter(obj) fmt = self.get_formatter(obj)
fmt_obj = fmt(obj) fmt_obj = fmt(obj)
if fmt_obj is not None: if fmt_obj is not None:
return f"{fmt(obj)} : {src}" fmt_obj = str(fmt_obj).replace("\n", " ")
if len(fmt_obj) > flags.MAX_FMT_LEN:
fmt_obj = fmt_obj[:flags.MAX_FMT_LEN - 5 - 10] + " ... " + fmt_obj[-10:]
else:
fmt_obj = fmt_obj
src = re.sub(NEWLINE, " ", src)
if len(src) > flags.MAX_EXPR_LEN:
src = src[:flags.MAX_EXPR_LEN//2 - 2] + " ... " + src[-flags.MAX_EXPR_LEN//2+3:]
return f"{fmt_obj} : {src}"
def exec(self, node: ast.stmt): def exec(self, node: ast.stmt):
src = to_source(node) src = self.to_source(node)
exec(src, self._globals, self._locals) exec(src, self._globals, self._locals)
def get_formatter(self, obj): def get_formatter(self, obj):
for typ, fmt in self._formatters: for typ, fmt in self._formatters:
if isinstance(typ, type) and isinstance(obj, typ): if isinstance(typ, type) and isinstance(obj, typ):
return fmt return fmt
elif isinstance(typ, (list, tuple)) and len(typ) and isinstance(typ[0], type) and isinstance(obj, typ):
return fmt
elif not isinstance(typ, type) and callable(typ) and typ(obj): elif not isinstance(typ, type) and callable(typ) and typ(obj):
return fmt return fmt
else: else:
@ -127,7 +150,7 @@ class Commentor(object):
def append_source(self, line=None): def append_source(self, line=None):
if self.state == flags.COMMENT: if self.state == flags.COMMENT:
self.__append('"""') self.__append('"""\n')
self.state = flags.SOURCE self.state = flags.SOURCE
if line is not None: if line is not None:
return self.__append(sign(line, 2)) return self.__append(sign(line, 2))

View File

@ -3,6 +3,8 @@ import os
bool_env = lambda name: os.environ.get(name, "false").lower() in ('true', '1', 'yes') bool_env = lambda name: os.environ.get(name, "false").lower() in ('true', '1', 'yes')
MAX_EXPR_LEN = 37
MAX_FMT_LEN = 45
DEBUG = bool_env("DEBUG") DEBUG = bool_env("DEBUG")
PRINT = bool_env("PRINT") PRINT = bool_env("PRINT")
INDENT = 4 INDENT = 4

View File

@ -1,17 +1,18 @@
import types import types
from .desc import desc
def silent(_): def silent(_):
return None return None
def Tensor(tensor): def Tensor(tensor):
return "Tensor" in tensor.__class__.__name__ return "Tensor" in tensor.__class__.__name__
def fmt_Tensor(tensor):
return f"{list(tensor.shape)}"
LIST = [ LIST = [
(callable, silent), (callable, silent),
(Tensor, fmt_Tensor), (Tensor, desc),
(types.ModuleType, silent), (types.ModuleType, silent),
((list, dict), desc),
] ]

View File

@ -0,0 +1,46 @@
def _apply(x, f):
if isinstance(x, (list)):
return x.__class__([_apply(xi, f) for xi in x])
if not isinstance(x, dict):
if _has_method(x, "to_dict"):
x = x.to_dict()
if _has_method(x, "as_dict"):
x = x.as_dict()
if _has_method(x, "items"):
x = dict(x)
if isinstance(x, dict):
return {k: _apply(v, f) for k, v in x.items()}
try:
return f(x)
except TypeError:
return _print(f"<Unknown Type: {x.__class__.__name__}>")
except Exception as e:
return _print(f"<{e}>")
def desc(x):
def _desc(_x):
if "Tensor" in _x.__class__.__name__:
dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
return _print(f"Tensor({tuple(_x.shape)}, {dtype})")
if isinstance(_x, (int, float, str)):
return _print(_x)
if _x is None:
return "None"
raise TypeError
return str(_apply(x, _desc))
class _print(str):
def __repr__(self) -> str:
return self
def _has_method(obj, methodname) -> bool:
return getattr(getattr(obj, methodname, None), "__call__", False) is not False

View File

@ -1,6 +1,6 @@
from .definitions import FunctionDef, Return from .definitions import FunctionDef, Return, Lambda
from .statements import Pass, Assign from .statements import Pass, Assign, AnnAssign
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp
from .literals import Constant, Tuple, List from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr
from .variables import Name from .variables import Name
from .control_flow import If, For, Continue, Break from .control_flow import If, For, Continue, Break

View File

@ -1,6 +1,5 @@
import ast import ast
from .. import flags from .. import flags
from ..utils import to_source
ELIF = 2 ELIF = 2
PASS = 4 PASS = 4
@ -18,20 +17,20 @@ def If(self, cmtor, state=0):
state = state | PASS state = state | PASS
if state & ELIF: if state & ELIF:
cmtor.append_source(f"elif {to_source(self.test)}: # {test_comment}") cmtor.append_source(f"elif {cmtor.to_source(self.test)}: # {test_comment}")
else: else:
cmtor.append_source(f"if {to_source(self.test)}: # {test_comment}") cmtor.append_source(f"if {cmtor.to_source(self.test)}: # {test_comment}")
cmtor.indent += flags.INDENT cmtor.indent += flags.INDENT
for stmt in self.body: for stmt in self.body:
if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt)) cmtor.append_source(cmtor.to_source(stmt))
if test: if test:
cmtor.process(stmt) cmtor.process(stmt)
test = cmtor._stack_event == flags.NORMAL test = cmtor._stack_event == flags.NORMAL
else: else:
cmtor._lines[-1] += " # skipped" cmtor._lines[-1] += " # skipped"
cmtor.append_source() cmtor.append_source("")
cmtor.indent -= flags.INDENT cmtor.indent -= flags.INDENT
if self.orelse: if self.orelse:
@ -45,16 +44,16 @@ def If(self, cmtor, state=0):
cmtor.indent += flags.INDENT cmtor.indent += flags.INDENT
for stmt in self.orelse: for stmt in self.orelse:
if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt)) cmtor.append_source(cmtor.to_source(stmt))
if test: if test:
cmtor.process(stmt) cmtor.process(stmt)
test = cmtor._stack_event == flags.NORMAL test = cmtor._stack_event == flags.NORMAL
cmtor.append_source() cmtor.append_source("")
cmtor.indent -= flags.INDENT cmtor.indent -= flags.INDENT
def For(self, cmtor): def For(self, cmtor):
cmtor.append_source(to_source(ast.For(self.target, self.iter, [], []))) cmtor.append_source(cmtor.to_source(ast.For(self.target, self.iter, [], [])))
loop_start: int = cmtor.next_line() loop_start: int = cmtor.next_line()
@ -68,15 +67,15 @@ def For(self, cmtor):
# enter new iteration (mantain locals()) # enter new iteration (mantain locals())
cmtor.append_source("") cmtor.append_source("")
cmtor.append_source("###### !new iteration! ######") # cmtor.append_source("###### !new iteration! ######")
cmtor._locals[REG_it] = it cmtor._locals[REG_it] = it
stmt = ast.Assign([self.target], ast.Name(REG_it, ast.Load())) stmt = ast.Assign([self.target], ast.Name(REG_it, ast.Load()))
cmtor.process(stmt) cmtor.exec(stmt)
# process body # process body
for stmt in self.body: for stmt in self.body:
if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt)) cmtor.append_source(cmtor.to_source(stmt))
if cmtor._stack_event == flags.NORMAL: if cmtor._stack_event == flags.NORMAL:
cmtor.process(stmt) cmtor.process(stmt)
else: else:
@ -92,12 +91,13 @@ def For(self, cmtor):
cmtor.indent -= flags.INDENT cmtor.indent -= flags.INDENT
# comment out all code except for the last iter # delete all code except for the last iter
for lineno in range(loop_start, last_iter_start): for lineno in range(loop_start, last_iter_start):
if cmtor._lines_category[lineno][0] == flags.SOURCE: cmtor._lines[lineno] = None
line: str = cmtor._lines[lineno] # if cmtor._lines_category[lineno][0] == flags.SOURCE:
if line.lstrip() and line.lstrip()[0] != "#": # line: str = cmtor._lines[lineno]
cmtor._lines[lineno] = " " * self_indent + "# " + line[self_indent:] # if line.lstrip() and line.lstrip()[0] != "#":
# cmtor._lines[lineno] = " " * self_indent + "# " + line[self_indent:]
def Break(self, cmtor): def Break(self, cmtor):

View File

@ -1,8 +1,7 @@
from .. import flags from .. import flags
from ..utils import to_source
def FunctionDef(self, cmtor): def FunctionDef(self, cmtor):
cmtor.append_source(f"def {self.name}({to_source(self.args)}):") cmtor.append_source(f"def {self.name}({cmtor.to_source(self.args)}):")
cmtor.indent += flags.INDENT cmtor.indent += flags.INDENT
if self is cmtor.root: if self is cmtor.root:
@ -13,7 +12,7 @@ def FunctionDef(self, cmtor):
for stmt in self.body: for stmt in self.body:
if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES: if type(stmt) not in flags.APPEND_SOURCE_BY_THEMSELVES:
cmtor.append_source(to_source(stmt)) cmtor.append_source(cmtor.to_source(stmt))
if self is cmtor.root: if self is cmtor.root:
cmtor.process(stmt) cmtor.process(stmt)
@ -26,3 +25,7 @@ def FunctionDef(self, cmtor):
def Return(self, cmtor): def Return(self, cmtor):
cmtor.process(self.value) cmtor.process(self.value)
cmtor._return = cmtor.eval(self.value, format=False) cmtor._return = cmtor.eval(self.value, format=False)
def Lambda(self, cmtor):
pass

View File

@ -55,3 +55,8 @@ def Slice(self, cmtor):
def keyword(self, cmtor): def keyword(self, cmtor):
cmtor.process(self.value) cmtor.process(self.value)
def IfExp(self, cmtor):
cmtor.append_comment(cmtor.eval(self.test))
cmtor.append_comment(cmtor.eval(self))

View File

@ -10,3 +10,21 @@ def Tuple(self, cmtor):
def List(self, cmtor): def List(self, cmtor):
for x in self.elts: for x in self.elts:
cmtor.process(x) cmtor.process(x)
def Set(self, cmtor):
for x in self.elts:
cmtor.process(x)
def Dict(self, cmtor):
for k, v in zip(self.keys, self.values):
cmtor.process(v)
def FormattedValue(self, cmtor):
pass
def JoinedStr(self, cmtor):
cmtor.append_comment(cmtor.eval(self))

View File

@ -6,7 +6,16 @@ def Pass(self, cmtor):
def Assign(self, cmtor): def Assign(self, cmtor):
cmtor.process(self.value) cmtor.process(self.value)
cmtor.exec(self) cmtor.exec(self)
if type(self.value) not in flags.ASSIGN_SILENT: if type(self.value) not in flags.ASSIGN_SILENT and len(self.targets):
cmtor.append_comment(f"----------") cmtor.append_comment(f"----------")
for target in self.targets: for target in self.targets:
cmtor.process(target) cmtor.process(target)
def AnnAssign(self, cmtor):
if getattr(self, "value", None) is not None:
cmtor.process(self.value)
cmtor.exec(self)
if type(self.value) not in flags.ASSIGN_SILENT:
cmtor.append_comment(f"----------")
cmtor.process(self.target)

View File

@ -19,10 +19,11 @@ def sign(line: str, depth=1):
return line return line
def to_source(node): def to_source(node, indent=0):
src = astor.to_source(node).rstrip("\n") src = astor.to_source(node).rstrip("\n")
if type(node) != ast.Tuple and len(src) > 2 and src[0] == "(" and src[-1] == ")": if type(node) != ast.Tuple and len(src) > 2 and src[0] == "(" and src[-1] == ")":
src = src[1:-1] src = src[1:-1]
src = src.replace("\n", "\n" + " " * indent)
return src return src
@ -30,7 +31,7 @@ def dump(node, file=sys.stderr):
print(ast.dump(node, indent=4), file=file) print(ast.dump(node, indent=4), file=file)
def comment_to_file(code, file: str) -> bool: def comment_to_file(code, file: str, before="", after="") -> bool:
if file == "<return>": if file == "<return>":
return False return False
elif file == "<stderr>": elif file == "<stderr>":
@ -38,18 +39,18 @@ def comment_to_file(code, file: str) -> bool:
syntax = Syntax(code, "python") syntax = Syntax(code, "python")
rich.print(syntax, file=sys.stderr) rich.print(syntax, file=sys.stderr)
else: else:
rich.print(code, file=sys.stderr) print(before + code + after, file=sys.stderr)
elif file == "<stdout>": elif file == "<stdout>":
if sys.stdout.isatty(): if sys.stdout.isatty():
syntax = Syntax(code, "python") syntax = Syntax(code, "python")
rich.print(syntax, file=sys.stdout) rich.print(syntax, file=sys.stdout)
else: else:
rich.print(code, file=sys.stdout) print(before + code + after, file=sys.stdout)
elif isinstance(file, IOBase): elif isinstance(file, IOBase):
rich.print(code, file=file) print(code, file=file)
elif type(file) == str: elif type(file) == str:
with open(file, "wt") as f: with open(file, "wt") as f:
rich.print(code, file=f) print(before + code + after, file=f)
else: else:
raise NotImplementedError(f"Unknown file protocal {file}") raise NotImplementedError(f"Unknown file protocal {file}")
return True return True