TimeDistributed()层应用于自定义层出现NotImplementedError

news/2024/11/14 15:06:01 标签: python, 深度学习, tensorflow, TimeDistributed, keras

 当使用tensorflow.keras.layers中的TimeDistributed应用于自定义层在时间维度进行扩展时,使输入数据在时间维度上的每个数据应用于相同的自定义层(或base_model),如:

model = Sequential()
model.add(TimeDistributed(base_model, input_shape=(15, 244, 244, 3)))

 训练时出现如下错误:

File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/layers/wrappers.py", line 210, in compute_output_shape
    child_output_shape = self.layer.compute_output_shape(child_input_shape)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py", line 639, in compute_output_shape
    raise NotImplementedError
NotImplementedError

该提示表明,没有重写如下compute_output_shape函数,从而引起NotImplementedError。

def compute_output_shape(self, input_shape):
    ... ...
    return tuple(output_shape)
  

 即从写compute_output_shape()函数,确定其输出形状,实现这个方法是自定义层(或模型)中有需要训练的参数。对于相对简单的自定义操作,可以通过Lambda层进行实现,但该实现方式的局限性只针对我们的自定义层中(或模型)不包含可训练的权重,否则可能出现错误,但是如果该层(或模型)中可训练的参数不进行更新,那个也可用该Lambda层进行实现,其实现方式如下所示:

# 2048 is the output size
model.add(
    Lambda(
        lambda x: tf.reshape(base_model(tf.reshape(x, [-1, 244, 244,3])),[-1, 15, 2048])
    , input_shape=(15, 244, 244, 3))
)

参考资料:

  1. (39条消息) keras实现自定义层的关键步骤解析_MIss-Y的博客-CSDN博客
  2. conv neural network - TimeDistributed of a KerasLayer in Tensorflow 2.0 - Stack Overflow
  3. How to Use the TimeDistributed Layer in Keras (machinelearningmastery.com)
  4. (39条消息) Keras学习笔记(二)Keras实现自定义层_buchidanhuang的博客-CSDN博客_keras 自定义层
  5. (39条消息) tensorflow2.x踩坑记录二:加载含Lambda层的模型时,出现name tf is not defined_耐心的小黑的博客-CSDN博客

http://www.niftyadmin.cn/n/1737978.html

相关文章

显示acc和lose时出现:KeyError: ‘sparse_categorical_accuracy‘

使用:historymodel.fit(训练集数据, 训练集标签, batch_size, epochs, validation_split用作测试数据的比例, validation_data测试集, validation_freq测试频率)训练之后,绘制其精度acc和损失loss函数曲线。 根据fit()函数的传入…

C++虚函数(三)

三. 虚函数使用技巧 3.1 private的虚函数   考虑下面的例子: class A{public:void foo() { bar();}private:virtual void bar() { ...}}; class B: public A{private:virtual void bar() { ...}}; 在这个例子中,虽然bar()在A类中是private的&#…

TensorFlow2.X绘制常见图像(如AUC,acc,recall等等)

前人已写,所以不重复造轮子了,顾粘上相关链接: https://www.freesion.com/article/5668431209/#METRICS_7 解释一下代码中用到的color colors[0],需要自定义相关颜色的列表,这里我们可以定义为: colors…

Keras中那些学习率衰减方法

(43条消息) Keras中那些学习率衰减策略_Siucaan-CSDN博客

常用脚本汇总

随机访问文件中的位置,但是需要保证不取到重复的数值 awk { print rand(),$1 } file.txt|sort -k1 |awk { print $2 } >result.txt awk { print rand(),$1 } file.txt在第一列加上随机数 awk { print rand(),$1 } file.txt|sort -k1按照第一列随机数排序 awk { p…

tensorflow2.x训练模型出现nan

1.报如下错误: tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [predictions must be > 0] [Condition x > y did not hold element-wise:] [x (sub_2:0) ] [[-nan][-nan][-nan]...] [y (Cast_2/x:0) ] [0][[{{node a…

损失函数softmax_cross_entropy、binary_cross_entropy、sigmoid_cross_entropy之间的区别与联系

cross_entropy-----交叉熵是深度学习中常用的一个概念,一般用来求目标与预测值之间的差距。 在介绍softmax_cross_entropy,binary_cross_entropy、sigmoid_cross_entropy之前,先来回顾一下信息量、熵、交叉熵等基本概念。 ----------------…

NetLog 大规模应用实战:Database-sharding 技术

一、背景 Netlog是一家社交网站社区,目前拥有大规模的应用数据,包括: 超过4000w的活跃用户数、每个月5000w的UV、每月50亿的PV、每月60亿的在线时长、支持26中语言,覆盖5个主要的欧洲国家,如意大利、德国,土…