- 2020-07-29 06:35
*views 1*- RNN

# Two-layer LSTM classifier for MNIST (TensorFlow 1.x graph mode).
#
# Each 28x28 image is fed to the RNN as a sequence of ``time_step`` = 28 steps,
# one image row of ``embedding_dim`` = 28 pixels per step. The output of the
# last time step goes through a dense layer that scores the 10 digit classes.
#
# Requires TensorFlow 1.x: ``tf.contrib`` and the bundled
# ``tensorflow.examples.tutorials.mnist`` loader are not available in TF 2.
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Labels are loaded one-hot: shape (batch_size, num_classes), dtype float.
# Without one_hot=True they would be integer class ids of shape (batch_size,).
mnist = input_data.read_data_sets('mnist_data/', one_hot=True)

num_classes = 10
batch_size = 64
hidden_dim1 = 32   # units in the first LSTM layer
hidden_dim2 = 64   # units in the second (top) LSTM layer
epochs = 10

embedding_dim = 28  # features per time step (one image row)
time_step = 28      # time steps per sequence (rows per image)

# Notes on tf.nn.dynamic_rnn(cell, inputs, initial_state=None, dtype=None,
#                            time_major=False)
#
# It returns two values: ``outputs`` and ``states``.
#
# * With time_major=False, ``outputs`` has shape
#   [batch_size, time_step, cell.output_size]. For each example in the batch
#   it holds a [time_step, output_size] matrix whose row t is the RNN output
#   after reading step t.
# * Transposing ``outputs`` to [time_step, batch_size, output_size] and taking
#   index -1 gives the output at the LAST time step, shape
#   [batch_size, output_size]. In a language model where one sequence is one
#   sentence, each row of this tensor summarizes the whole sentence, which is
#   why it is the usual input to the next layer. Taking index 0 instead would
#   only reflect the first word; feeding that forward gives a much worse model.
# * For an LSTM cell, ``states`` is a tuple of two tensors, because an LSTM
#   carries two states: the cell state c (states[0]) and the hidden state h
#   (states[1]). Both are the values after the final time step, so their
#   shape, [batch_size, num_units], does not depend on time_step.
# * Consequently states[1] (h) equals the last-time-step slice of ``outputs``
#   described above (no sequence_length is passed here, so every sequence
#   runs all time_step steps). With MultiRNNCell, as used below, ``states``
#   holds one such (c, h) tuple per layer.


class RNN_model:
    """Two stacked LSTM layers followed by a dense softmax classifier.

    Call ``build_graph()`` to create the placeholders, network, loss,
    optimizer and accuracy ops, then ``train()`` to fit on MNIST.
    """

    def __init__(self):
        # Start from an empty default graph so constructing the model again
        # does not accumulate duplicate variables/ops.
        tf.reset_default_graph()

    def add_placeholder(self):
        """Create the input placeholders ``self.xs`` and ``self.ys``."""
        # Flattened images: time_step * embedding_dim (= 784) pixels each.
        self.xs = tf.placeholder(shape=[None, time_step * embedding_dim],
                                 dtype=tf.float32)
        # One-hot labels: shape (batch_size, num_classes), dtype float32.
        # For integer labels use tf.placeholder(shape=[None], dtype=tf.int32)
        # together with sparse_softmax_cross_entropy_with_logits.
        self.ys = tf.placeholder(shape=[None, num_classes], dtype=tf.float32)

    def rnn_layer(self):
        """Run the 2-layer LSTM; store the last-step output in ``self.rnn_output``."""
        # (batch, 784) -> (batch, time_step, embedding_dim): one row per step.
        rnn_input = tf.reshape(tensor=self.xs,
                               shape=[-1, time_step, embedding_dim])
        cell_1 = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim1)
        cell_2 = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim2)
        cells = tf.contrib.rnn.MultiRNNCell([cell_1, cell_2])
        # Size the initial state from the fed data instead of the training
        # constant, so the same graph can evaluate batches of any size
        # (e.g. the whole test set at once).
        dynamic_batch_size = tf.shape(rnn_input)[0]
        initial_state = cells.zero_state(batch_size=dynamic_batch_size,
                                         dtype=tf.float32)
        outputs, states = tf.nn.dynamic_rnn(cells, rnn_input,
                                            initial_state=initial_state,
                                            time_major=False)
        # outputs.shape == (batch_size, time_step, hidden_dim2)
        # states[0][0].shape == states[0][1].shape == (batch_size, hidden_dim1)
        # states[1][0].shape == states[1][1].shape == (batch_size, hidden_dim2)
        # states[1][1] == tf.transpose(outputs, [1, 0, 2])[-1]
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
        self.rnn_output = outputs[-1]  # (batch_size, hidden_dim2)

    def output_layer(self):
        """Project ``self.rnn_output`` to class logits ``self.predict``."""
        weights = tf.Variable(tf.random_normal(shape=[hidden_dim2, num_classes],
                                               dtype=tf.float32))
        biases = tf.Variable(tf.random_normal(shape=[num_classes],
                                              dtype=tf.float32))
        # (batch_size, num_classes)
        self.predict = tf.matmul(self.rnn_output, weights) + biases

    def loss_layer(self):
        """Create ``self.loss``, ``self.train_op`` and ``self.accuracy``.

        softmax_cross_entropy_with_logits_v2 requires logits and labels of the
        same shape (batch_size, num_classes) and the same float dtype. Its
        sparse counterpart, sparse_softmax_cross_entropy_with_logits, instead
        takes integer labels of shape (batch_size,), each giving the class
        index of the matching logits row.
        """
        # _v2 replaces the deprecated softmax_cross_entropy_with_logits. The
        # labels are a placeholder, so the loss value and gradients are the same.
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.predict,
                                                       labels=self.ys))
        self.train_op = tf.train.AdamOptimizer(0.01).minimize(self.loss)
        correct = tf.equal(tf.argmax(self.ys, 1), tf.argmax(self.predict, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))

    def build_graph(self):
        """Build placeholders, network, loss, optimizer and accuracy ops."""
        self.add_placeholder()
        self.rnn_layer()
        self.output_layer()
        self.loss_layer()

    def train(self):
        """Train for ``epochs`` epochs, report test accuracy, save a checkpoint."""
        num_batches = mnist.train.num_examples // batch_size
        checkpoint_path = "checkpoints/rnn_mnist.ckpt"
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(epochs):
                epoch_loss = 0.0
                for _ in range(num_batches):
                    batch_x, batch_y = mnist.train.next_batch(batch_size)
                    feed_dict = {self.xs: batch_x, self.ys: batch_y}
                    _, loss_value = sess.run([self.train_op, self.loss],
                                             feed_dict=feed_dict)
                    epoch_loss += loss_value.item()
                # Evaluate on the whole test set rather than one random batch
                # of 64; rnn_layer sizes the initial state from the fed data.
                acc = sess.run(self.accuracy,
                               feed_dict={self.xs: mnist.test.images,
                                          self.ys: mnist.test.labels})
                print("After %d epoch,loss value is %f ,and accuracy is %f "
                      % (epoch + 1, epoch_loss / num_batches, acc.item()))
            # tf.train.Saver.save fails if the parent directory is missing.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            saver.save(sess, checkpoint_path)


if __name__ == "__main__":
    model = RNN_model()
    model.build_graph()
    model.train()

Technology

Daily Recommendation

views 26

views 2

©2019-2020 Toolsou All rights reserved,

1190 Reverses the substring between each pair of parentheses leetcodemysql Joint index details You don't know ——HarmonyOS Create data mysql Library process Character recognition technology of vehicle license plate based on Neural Network A guess number of small games , use JavaScript realization Talking about uni-app Page value transfer problem pytorch of ResNet18（ Yes cifar10 The accuracy of data classification is achieved 94%）C++ Method of detecting memory leak One is called “ Asking for the train ” A small village Finally got the train