一、GRU介绍
GRU是LSTM网络的一种效果很好的变体,它较LSTM网络的结构更加简单,而且效果也很好,因此也是当前非常流形的一种网络。GRU既然是LSTM的变体,因此也是可以解决RNN网络中的长依赖问题。
GRU的参数较少,因此训练速度更快,GRU能够降低过拟合的风险。
在LSTM中引入了三个门函数:输入门、遗忘门和输出门来控制输入值、记忆值和输出值。而在GRU模型中只有两个门:分别是更新门和重置门。具体结构如下图所示:
·
图中的zt和rt分别表示更新门和重置门。更新门用于控制前一时刻的状态信息被带入到当前状态中的程度,更新门的值越大说明前一时刻的状态信息带入越多。重置门控制前一状态有多少信息被写入到当前的候选集 h~t
二、GRU与LSTM的比较
- GRU相比于LSTM少了输出门,其参数比LSTM少。
- GRU在复调音乐建模和语音信号建模等特定任务上的性能和LSTM差不多,在某些较小的数据集上,GRU相比于LSTM表现出更好的性能。
- LSTM比GRU严格来说更强,因为它可以很容易地进行无限计数,而GRU却不能。这就是GRU不能学习简单语言的原因,而这些语言是LSTM可以学习的。
- GRU网络在首次大规模的神经网络机器翻译的结构变化分析中,性能始终不如LSTM。
三、GRU的API
rnn = nn.GRU(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional)
初始化:
input_size: input的特征维度
hidden_size: 隐藏层的宽度
num_layers: 单元的数量(层数),默认为1,如果为2以为着将两个GRU堆叠在一起,当成一个GRU单元使用。
bias: True or False,是否使用bias项,默认使用
batch_first: Ture or False, 默认的输入是三个维度的,即:(seq, batch, feature),第一个维度是时间序列,第二个维度是batch,第三个维度是特征。如果设置为True,则(batch, seq, feature)。即batch,时间序列,每个时间点特征。
dropout:设置隐藏层是否启用dropout,默认为0
bidirectional:True or False, 默认为False,是否使用双向的GRU,如果使用双向的GRU,则自动将序列正序和反序各输入一次。
输入:
rnn(input, h_0)
输出:
output, hn = rnn(input, h0)
形状的和LSTM差不多,也有双向
四、情感分类demo修改成GRU
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import os
import re
import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
dataset_path = r'C:Usersci21615DownloadsaclImdb_v1aclImdb'
MAX_LEN = 500
def tokenize(text):
"""
分词,处理原始文本
:param text:
:return:
"""
fileters = ['!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '', '?', '@'
, '[', '\', ']', '^', '_', '`', '{', '|', '}', '~', 't', 'n', 'x97', 'x96', '”', '“', ]
text = re.sub("<.>", " ", text, flags=re.S)
text = re.sub("|".join(fileters), " ", text, flags=re.S)
return [i.strip() for i in text.split()]
class ImdbDataset(Dataset):
"""
准备数据集
"""
def __init__(self, mode):
super(ImdbDataset, self).__init__()
if mode == 'train':
text_path = [os.path.join(dataset_path, i) for i in ['train/neg', 'train/pos']]
else:
text_path = [os.path.join(dataset_path, i) for i in ['test/neg', 'test/pos']]
self.total_file_path_list = []
for i in text_path:
self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])
def __getitem__(self, item):
cur_path = self.total_file_path_list[item]
cur_filename = os.path.basename(cur_path)
# 获取标签
label_temp = int(cur_filename.split('_')[-1].split('.')[0]) - 1
label = 0 if label_temp min}
# 删除词频大于max的word
if max is not None:
self.count = {word:value for word,value in self.count.items() if value len(sentence):
sentence = sentence + [self.PAD_TAG] * (max_len - len(sentence)) # 填充
if max_len
结果展示:
服务器托管,北京服务器托管,服务器租用 http://www.fwqtg.net
机房租用,北京机房租用,IDC机房托管, http://www.e1idc.net