此工程解读链接(建议按顺序阅读):
其实在看这里的代码的时候感觉是最轻松的,但同时又是最费时间的。轻松是因为这里的代码大体上做了些什么都比较好懂,费时间是因为里面涉及了很多python的运算操作,一层套一层,如果不是非常熟练的话(比如说我)看起来还是有点尴尬。
所以我在这里强烈推荐像我一样对这里python操作不太熟练的小伙伴一步步debug看一下,或者说把部分代码粘出来,自己写个小文本文件load进去看一下,还是十分有帮助的。
功夫下得深,铁杵磨成针!同志们我们就快把这个工程搞定了,再坚持一下!
#-*-coding:utf-8-*-
import codecs
import os
import collections
from six.moves import cPickle
import numpy as np
class TextLoader():
def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
self.data_dir = data_dir
self.batch_size = batch_size
self.seq_length = seq_length
self.encoding = encoding
#读入数据文件,一般我们只准备了第一个txt文件,后面再生成后连两个文件
input_file = os.path.join(data_dir, "input.txt")
vocab_file = os.path.join(data_dir, "vocab.pkl")
tensor_file = os.path.join(data_dir, "data.npy")
if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
print("reading text file")
self.preprocess(input_file, vocab_file, tensor_file)
else:
print("loading preprocessed files")
self.load_preprocessed(vocab_file, tensor_file)
self.create_batches()
self.reset_batch_pointer()
def preprocess(self, input_file, vocab_file, tensor_file):
with codecs.open(input_file, "r", encoding=self.encoding) as f:
data = f.read()
#统计一共多少字,相当于用了list(set(data)),
#collection.Counter这个python模块真是太方便了,很方便的统计次数,推荐!
counter = collections.Counter(data)
#相当于counter.most_commen(),按照次数排序
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
#压缩成一个[(letters),(frequencies)]的形式
#得到的char就是(letters),里面不重复的按照从高到低的顺序存放着每个字符
self.chars, _ = zip(*count_pairs)
self.vocab_size = len(self.chars)
#vocab是一个字典,存放着每个字符对应着的出现次数
self.vocab = dict(zip(self.chars, range(len(self.chars))))
#vocab_file.pkl里存放了一个tuple,里面不重复的按照从高到低的顺序存放着每个字符
with open(vocab_file, 'wb') as f:
cPickle.dump(self.chars, f)
#data.npy(tensorfile)里存放了每个字符的出现次数,
# 这里一共有1115394个字符,那它就是一个长度为1115394的numpy array
self.tensor = np.array(list(map(self.vocab.get, data)))
np.save(tensor_file, self.tensor)
#载入变量
def load_preprocessed(self, vocab_file, tensor_file):
with open(vocab_file, 'rb') as f:
self.chars = cPickle.load(f)
self.vocab_size = len(self.chars)
self.vocab = dict(zip(self.chars, range(len(self.chars))))
self.tensor = np.load(tensor_file)
self.num_batches = int(self.tensor.size / (self.batch_size *
self.seq_length))
def create_batches(self):
self.num_batches = int(self.tensor.size / (self.batch_size *
self.seq_length))
# When the data (tensor) is too small, let's give them a better error message
if self.num_batches==0:
assert False, "Not enough data. Make seq_length and batch_size small."
#self.num_batches * self.batch_size * self.seq_length
# 这个东西不就是上面用到的tensor.size吗。。。
self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
xdata = self.tensor
ydata = np.copy(self.tensor)
#回想一下rnn的输入和输出,x和y要错一位,这里没有设置<begin>和<end>
ydata[:-1] = xdata[1:]
ydata[-1] = xdata[0]
#将x和y按batch_size切成了很多batch
#在这里他们是是有446个batch的list,即[[...],[...],[...],...]
self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)
def next_batch(self):
x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
#pointer每个batch过后向后移动一位
self.pointer += 1
return x, y
def reset_batch_pointer(self):
self.pointer = 0
python train.py
#更改参数的话举例如下
python train.py —num_layers 2
看到这里,大家应该已经对这个工程有一个大体的了解了,我们的任务也已经进行了85%了。下面的sample是干什么的呢?因为训练完了我们要看一下效果嘛,所以写了一个sample程序跑一下看看到底我们的模型可以生成什么样的文本,一起来看一下吧。
参考资料: