feat(trace_commentor): 添加对ast.ExtSlice和nn.Parameter的支持
- 在`formatters/__init__.py`中添加了`Parameter`函数,用于判断对象是否为`Parameter`类型。 - 在`formatters/desc.py`中的`desc`函数里添加了对`Parameter`对象的描述逻辑,包括数据类型和形状的格式化输出。 - 在`handlers/__init__.py`中将`Index`替换为`ExtSlice`,以支持扩展切片的处理。 - 在`handlers/expressions.py`中添加了`ExtSlice`函数,用于处理扩展切片表达式。
This commit is contained in:
parent
963a65fa71
commit
e0e0216312
36
pyproject.toml
Normal file
36
pyproject.toml
Normal file
@ -0,0 +1,36 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=65.5.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "trace_commentor" # pip install
|
||||
version = "2025.7.1"
|
||||
description = "..."
|
||||
readme = "README.md" # 可选, 如果你有 README.md 文件
|
||||
requires-python = ">=3.8"
|
||||
license = {text = "Apache-2.0"} # 或者你的许可证
|
||||
authors = [
|
||||
{name = "Yuyao Huang (Sam)", email = "huangyuyao@outlook.com"},
|
||||
]
|
||||
dependencies = [
|
||||
"astor",
|
||||
"rich",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["trace_commentor"]
|
||||
|
||||
# 包含非 Python 文件
|
||||
# [tool.setuptools.package-data]
|
||||
# "package.path" = ["*.json"]
|
||||
|
||||
# 命令行程序入口
|
||||
# [project.scripts]
|
||||
# script-name = "module.path:function"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
"pytest-mock",
|
||||
]
|
||||
@ -10,9 +10,14 @@ def Tensor(tensor):
|
||||
return "Tensor" in tensor.__class__.__name__
|
||||
|
||||
|
||||
def Parameter(parameter):
|
||||
return "Parameter" in parameter.__class__.__name__
|
||||
|
||||
|
||||
LIST = [
|
||||
(callable, silent),
|
||||
(Tensor, desc),
|
||||
(Parameter, desc),
|
||||
(types.ModuleType, silent),
|
||||
((list, dict), desc),
|
||||
]
|
||||
|
||||
@ -24,6 +24,9 @@ def _apply(x, f):
|
||||
|
||||
def desc(x):
|
||||
def _desc(_x):
|
||||
if 'Parameter' in _x.__class__.__name__:
|
||||
dtype = str(_x.data.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
|
||||
return _print(f"Parameter({tuple(_x.data.shape)}, {dtype})")
|
||||
if "Tensor" in _x.__class__.__name__:
|
||||
dtype = str(_x.dtype).replace("torch.", "").replace("float", "f").replace("uint", "u").replace("int", "i")
|
||||
return _print(f"Tensor({tuple(_x.shape)}, {dtype})")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from .definitions import FunctionDef, Return, Lambda
|
||||
from .statements import Pass, Assign, AnnAssign, AugAssign
|
||||
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, keyword, IfExp, Index
|
||||
from .expressions import Expr, UnaryOp, BinOp, Call, Compare, Attribute, Subscript, Slice, ExtSlice, keyword, IfExp, Index
|
||||
from .literals import Constant, Tuple, List, Set, Dict, FormattedValue, JoinedStr
|
||||
from .variables import Name
|
||||
from .control_flow import If, For, Continue, Break, With, Raise
|
||||
|
||||
@ -55,6 +55,11 @@ def Slice(self, cmtor):
|
||||
cmtor.process(self.step)
|
||||
|
||||
|
||||
def ExtSlice(self, cmtor):
|
||||
for dim in self.dims:
|
||||
cmtor.process(dim)
|
||||
|
||||
|
||||
def keyword(self, cmtor):
|
||||
cmtor.process(self.value)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user