Embedding是什么?详解nn.Embedding
Embedding是什么?
其为一个简单的存储固定大小的词典的嵌入向量的查找表,意思就是说,给一个编号,嵌入层就能返回这个编号对应的嵌入向量,嵌入向量反映了各个编号对应的符号的语义信息(蕴含了所有符号的语义关系)。
输入为一个编号列表,输出为对应的符号嵌入向量列表。
pytorch中的使用
#建立词向量层
embed = torch.nn.Embedding(n_vocabulary,embedding_size)
简单解释
embeding是一个词典,可以学习。
如:nn.Embedding(2, num_hiddens)
就是一个embedding。
输入索引,可以查到对应的向量值。
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
就是一个segment的索引矩阵,用self.segment_embedding(segments)
之后就可以得到它的向量值。
引用