重用预训练的嵌入


重用预训练的嵌入

TensorFlow Hub项目可以轻松地在自己的模型中重用经过预训练的模型组件。这些模型组件称为模块。只需要浏览TF Hub储存库,就能找到需要的,然后将代码示例复制到下项目中,该模块将连同其预先训练的权重一起自动下载并包含在模型中:

# 在情感分析模型中使用nnlm-en-dim50句子嵌入模块:
import tensorflow_hub as hub
import tensorflow as tf
from tensorflow import keras

model = keras.models.Sequential([
    hub.KerasLayer('https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1', dtype=tf.string, input_shape=[],
                   output_shape=[50]),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(1, activation='relu')
])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

hub.KerasLayer层会从给定的URL下载模块。这个特殊模块是一个句子编码器:它把字符串作为输入,并将每个字符串编码为单个向量(在这个情况下为50维向量)。在内部,它解析字符串(用空格来分隔单词),并使用大型语料库(Google News 7B语料库,长70亿个单词)上预训练的嵌入矩阵来嵌入每个单词。最后,它将计算所有词嵌入的均值,其结果就是句子嵌入。然后可以添加两个简单的Dense层来创建一个连搞得情感分析模型。默认情况下,hub.KerasLayer是不可训练的,但是可以在创建它时设置trainable=True来更改它。

接下来只需要加载IMDB评论数据集即可,无需对其进行预处理(除了批处理和预取)并直接训练模型:

import tensorflow_datasets as tfds

datasets, info = tfds.load('imdb_reviews', as_supervised=True, with_info=True)
train_size = info.splits['train'].num_examples
batch_size = 32
train_set = datasets['train'].batch(batch_size).prefetch(1)
history = model.fit(train_set, epochs=5)

TF Hub模块的URL最后一部分指定想要模型的版本1。此版本控制可以确保如果发布了新的模块版本,不会破坏模型。方便的是,如果在网络浏览器中输入此URL,会得到此模块的文档。默认情况下,TF Hub会将下载的文件缓存到本地系统的临时目录中。

相关