From e0e0216312d51854781fbe2b059fdaf4778d74e6 Mon Sep 17 00:00:00 2001 From: "Yuyao Huang (Sam)" Date: Fri, 4 Jul 2025 12:57:12 +0800 Subject: [PATCH] =?UTF-8?q?feat(trace=5Fcommentor):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=AF=B9ast.ExtSlice=E5=92=8Cnn.Parameter=E7=9A=84=E6=94=AF?= =?UTF-8?q?=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在`formatters/__init__.py`中添加了`Parameter`函数,用于判断对象是否为`Parameter`类型。 - 在`formatters/desc.py`中的`desc`函数里添加了对`Parameter`对象的描述逻辑,包括数据类型和形状的格式化输出。 - 在`handlers/__init__.py`中将`Index`替换为`ExtSlice`,以支持扩展切片的处理。 - 在`handlers/expressions.py`中添加了`ExtSlice`函数,用于处理扩展切片表达式。 --- pyproject.toml | 36 +++++++++++++++++++++++++ trace_commentor/formatters/__init__.py | 5 ++++ trace_commentor/formatters/desc.py | 3 +++ trace_commentor/handlers/__init__.py | 2 +- trace_commentor/handlers/expressions.py | 5 ++++ 5 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d235d3f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=65.5.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "trace_commentor" # pip install +version = "2025.7.1" +description = "..." +readme = "README.md" # 可选, 如果你有 README.md 文件 +requires-python = ">=3.8" +license = {text = "Apache-2.0"} # 或者你的许可证 +authors = [ + {name = "Yuyao Huang (Sam)", email = "huangyuyao@outlook.com"}, +] +dependencies = [ + "astor", + "rich", +] + +[tool.setuptools] +py-modules = ["trace_commentor"] + +# 包含非 Python 文件 +# [tool.setuptools.package-data] +# "package.path" = ["*.json"] + +# 命令行程序入口 +# [project.scripts] +# script-name = "module.path:function" + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-cov", + "pytest-mock", +] diff --git a/trace_commentor/formatters/__init__.py b/trace_commentor/formatters/__init__.py index 2fe6c06..40b9fb6 100644 --- a/trace_commentor/formatters/__init__.py +++ b/trace_commentor/formatters/__init__.py @@ -10,9 +10,14 @@ def Tensor(tensor): return "Tensor" in tensor.__class__.__name__ +def Parameter(parameter): + return "Parameter" in parameter.__class__.__name__ + + LIST = [ (callable, silent), (Tensor, desc), + (Parameter, desc), (types.ModuleType, silent), ((list, dict), desc), ] diff --git a/trace_commentor/formatters/desc.py b/trace_commentor/formatters/desc.py index 374b95c..8ebde39 100644 --- a/trace_commentor/formatters/desc.py +++ b/trace_commentor/formatters/desc.py @@ -24,6 +24,9 @@ def _apply(x, f): def desc(x): def _desc(_x): + if 'Parameter' in _x.__class__.__name__: + dtype = str(_x.data.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i") + return _print(f"Parameter({tuple(_x.data.shape)}, {dtype})") if "Tensor" in _x.__class__.__name__: dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i") return _print(f"Tensor({tuple(_x.shape)}, {dtype})") diff --git a/trace_commentor/handlers/__init__.py b/trace_commentor/handlers/__init__.py index ebe60e4..da9a0f7 100644 --- a/trace_commentor/handlers/__init__.py +++ b/trace_commentor/handlers/__init__.py @@ -1,6 +1,6 @@ from .definitions import FunctionDef, Return, Lambda from .statements import Pass, Assign, AnnAssign, AugAssign -from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp, Index +from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, ExtSlice, keyword, IfExp, Index from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr from .variables import Name from .control_flow import If, For, Continue, Break, With, Raise diff --git a/trace_commentor/handlers/expressions.py b/trace_commentor/handlers/expressions.py index 76f2fb2..f2e5927 100644 --- a/trace_commentor/handlers/expressions.py +++ b/trace_commentor/handlers/expressions.py @@ -55,6 +55,11 @@ def Slice(self, cmtor): cmtor.process(self.step) +def ExtSlice(self, cmtor): + for dim in self.dims: + cmtor.process(dim) + + def keyword(self, cmtor): cmtor.process(self.value)