您的当前位置:首页正文

RNN代码解读之char-RNN with TensorFlow(util.py)

2024-11-28 来源:个人技术集锦

此工程解读链接(建议按顺序阅读):



其实在看这里的代码的时候感觉是最轻松的,但同时又是最费时间的。轻松是因为这里的代码大体上做了些什么都比较好懂,费时间是因为里面涉及了很多python的运算操作,一层套一层,如果不是非常熟练的话(比如说我)看起来还是有点尴尬。

所以我在这里强烈推荐像我一样对这里python操作不太熟练的小伙伴一步步debug看一下,或者说把部分代码粘出来,自己写个小文本文件load进去看一下,还是十分有帮助的。

功夫下得深,铁杵磨成针!同志们我们就快把这个工程搞定了,再坚持一下!

#-*-coding:utf-8-*-
import codecs
import os
import collections
from six.moves import cPickle
import numpy as np

class TextLoader():
    def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.encoding = encoding
        #读入数据文件,一般我们只准备了第一个txt文件,后面再生成后连两个文件
        input_file = os.path.join(data_dir, "input.txt")
        vocab_file = os.path.join(data_dir, "vocab.pkl")
        tensor_file = os.path.join(data_dir, "data.npy")

        if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
            print("reading text file")
            self.preprocess(input_file, vocab_file, tensor_file)
        else:
            print("loading preprocessed files")
            self.load_preprocessed(vocab_file, tensor_file)
        self.create_batches()
        self.reset_batch_pointer()

    def preprocess(self, input_file, vocab_file, tensor_file):
        with codecs.open(input_file, "r", encoding=self.encoding) as f:
            data = f.read()
        #统计一共多少字,相当于用了list(set(data)),
        #collection.Counter这个python模块真是太方便了,很方便的统计次数,推荐!
        counter = collections.Counter(data)
        #相当于counter.most_commen(),按照次数排序
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        #压缩成一个[(letters),(frequencies)]的形式
        #得到的char就是(letters),里面不重复的按照从高到低的顺序存放着每个字符
        self.chars, _ = zip(*count_pairs)
        self.vocab_size = len(self.chars)
        #vocab是一个字典,存放着每个字符对应着的出现次数
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        #vocab_file.pkl里存放了一个tuple,里面不重复的按照从高到低的顺序存放着每个字符
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
        #data.npy(tensorfile)里存放了每个字符的出现次数,
        # 这里一共有1115394个字符,那它就是一个长度为1115394的numpy array
        self.tensor = np.array(list(map(self.vocab.get, data)))
        np.save(tensor_file, self.tensor)

    #载入变量
    def load_preprocessed(self, vocab_file, tensor_file):
        with open(vocab_file, 'rb') as f:
            self.chars = cPickle.load(f)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.tensor = np.load(tensor_file)
        self.num_batches = int(self.tensor.size / (self.batch_size *
                                                   self.seq_length))

    def create_batches(self):
        self.num_batches = int(self.tensor.size / (self.batch_size *
                                                   self.seq_length))

        # When the data (tensor) is too small, let's give them a better error message
        if self.num_batches==0:
            assert False, "Not enough data. Make seq_length and batch_size small."
        #self.num_batches * self.batch_size * self.seq_length
        # 这个东西不就是上面用到的tensor.size吗。。。
        self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
        xdata = self.tensor
        ydata = np.copy(self.tensor)
        #回想一下rnn的输入和输出,x和y要错一位,这里没有设置<begin>和<end>
        ydata[:-1] = xdata[1:]
        ydata[-1] = xdata[0]
        #将x和y按batch_size切成了很多batch
        #在这里他们是是有446个batch的list,即[[...],[...],[...],...]
        self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
        self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)


    def next_batch(self):
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        #pointer每个batch过后向后移动一位
        self.pointer += 1
        return x, y

    def reset_batch_pointer(self):
        self.pointer = 0
python train.py
#更改参数的话举例如下
python train.py —num_layers 2 

看到这里,大家应该已经对这个工程有一个大体的了解了,我们的任务也已经进行了85%了。下面的sample是干什么的呢?因为训练完了我们要看一下效果嘛,所以写了一个sample程序跑一下看看到底我们的模型可以生成什么样的文本,一起来看一下吧。

参考资料:


显示全文