|
1 报错描述
1.1 系统环境
Hardware Environment(Ascend/GPU/CPU): Ascend
Software Environment:
– MindSpore version (source or binary): 1.8.0
– Python version (e.g., Python 3.7.5): 3.7.6
– OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 4.15.0-74-generic
– GCC/Compiler version (if compiled from source):
1.2 基本信息
1.2.1 脚本
训练脚本是通过构建CellList的单算子网络,实现cell列表容器。脚本如下:
01 class ListNoneExample(nn.Cell):
02 def __init__(self):
03 super(ListNoneExample, self).__init__()
04 self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])
05
06 def construct(self, x):
07 output = []
08 for op in self.lst:
09 output.append(op(x))
10 return output
11
12 input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
13 example = ListNoneExample()
14 output = example(input)
15 print("Output:", output)1.2.2 报错
这里报错信息如下:
Traceback (most recent call last):
File &#34;C:/Users/l30026544/PycharmProjects/q2_map/new/I3OGVW.py&#34;, line 31, in <module>
example = ListNoneExample()
File &#34;C:/Users/l30026544/PycharmProjects/q2_map/new/I3OGVW.py&#34;, line 19, in __init__
self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])
File &#34;C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py&#34;, line 310, in __init__
self.extend(args[0])
File &#34;C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py&#34;, line 405, in extend
if _valid_cell(cell, cls_name):
File &#34;C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py&#34;, line 39, in _valid_cell
raise TypeError(f&#39;{msg_prefix} each cell should be subclass of Cell, but got {type(cell).__name__}.&#39;)
TypeError: For &#39;CellList&#39;, each cell should be subclass of Cell, but got NoneType.原因分析
我们看报错信息,在TypeError中,写到For ‘CellList’, each cell should be subclass of Cell, but got NoneType.
,意思是对于CellList这个算子, 传入的每一个cell都因该是nn.Cell的子类, 但是得到了None类型。检查网络中初始化CellList的行为第4行, 发现传入了一个None, 因此报错。为了解决这个问题, 只需把这里的None换成一个继承于基类Cell类的对象, 就能实现相同的功能。
2 解决方法
基于上面已知的原因,很容易做出如下修改:
01 class NoneCell(nn.Cell):
02 def __init__(self):
03 super(NoneCell, self).__init__()
04
05 def construct(self, x):
06 return x
07
08 class ListNoneExample(nn.Cell):
09 def __init__(self):
10 super(ListNoneExample, self).__init__()
11 self.lst = nn.CellList([nn.ReLU(), NoneCell(), nn.ReLU()])
12
13 def construct(self, x):
14 output = []
15 for op in self.lst:
16 output.append(op(x))
17 return output
18
19 input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
20 example = ListNoneExample()
21 output = example(input)
22 print(&#34;Output:&#34;, output)此时执行成功,输出如下:
Output: (Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[0.00000000e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[-2.74355006e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[0.00000000e+000]]))3 总结
定位报错问题的步骤:
1、找到报错的用户代码行:self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()]);
2、 根据日志报错信息中的关键字,缩小分析问题的范围each cell should be subclass of Cell, but got NoneType ;
3、需要重点关注变量定义、初始化的正确性。
4 参考文档
4.1 CellList算子API接口 |
|