feat(trace_commentor): 添加对ast.ExtSlice和nn.Parameter的支持

- 在`formatters/__init__.py`中添加了`Parameter`函数,用于判断对象是否为`Parameter`类型。
- 在`formatters/desc.py`中的`desc`函数里添加了对`Parameter`对象的描述逻辑,包括数据类型和形状的格式化输出。
- 在`handlers/__init__.py`中将`Index`替换为`ExtSlice`,以支持扩展切片的处理。
- 在`handlers/expressions.py`中添加了`ExtSlice`函数,用于处理扩展切片表达式。
This commit is contained in:
Yuyao Huang (Sam) 2025-07-04 12:57:12 +08:00
parent 963a65fa71
commit e0e0216312
5 changed files with 50 additions and 1 deletions

36
pyproject.toml Normal file
View File

@ -0,0 +1,36 @@
[build-system]
requires = ["setuptools>=65.5.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "trace_commentor" # pip install
version = "2025.7.1"
description = "..."
readme = "README.md" # 可选, 如果你有 README.md 文件
requires-python = ">=3.8"
license = {text = "Apache-2.0"} # 或者你的许可证
authors = [
{name = "Yuyao Huang (Sam)", email = "huangyuyao@outlook.com"},
]
dependencies = [
"astor",
"rich",
]
[tool.setuptools]
py-modules = ["trace_commentor"]
# 包含非 Python 文件
# [tool.setuptools.package-data]
# "package.path" = ["*.json"]
# 命令行程序入口
# [project.scripts]
# script-name = "module.path:function"
[project.optional-dependencies]
dev = [
"pytest",
"pytest-cov",
"pytest-mock",
]

View File

@ -10,9 +10,14 @@ def Tensor(tensor):
return "Tensor" in tensor.__class__.__name__
def Parameter(parameter):
return "Parameter" in parameter.__class__.__name__
LIST = [
(callable, silent),
(Tensor, desc),
(Parameter, desc),
(types.ModuleType, silent),
((list, dict), desc),
]

View File

@ -24,6 +24,9 @@ def _apply(x, f):
def desc(x):
def _desc(_x):
if 'Parameter' in _x.__class__.__name__:
dtype = str(_x.data.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
return _print(f"Parameter({tuple(_x.data.shape)}, {dtype})")
if "Tensor" in _x.__class__.__name__:
dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
return _print(f"Tensor({tuple(_x.shape)}, {dtype})")

View File

@ -1,6 +1,6 @@
from .definitions import FunctionDef, Return, Lambda
from .statements import Pass, Assign, AnnAssign, AugAssign
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp, Index
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, ExtSlice, keyword, IfExp, Index
from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr
from .variables import Name
from .control_flow import If, For, Continue, Break, With, Raise

View File

@ -55,6 +55,11 @@ def Slice(self, cmtor):
cmtor.process(self.step)
def ExtSlice(self, cmtor):
for dim in self.dims:
cmtor.process(dim)
def keyword(self, cmtor):
cmtor.process(self.value)