- 2020-07-29 06:35
RNN

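import tensorflow as tf  # TensorFlow 1.x API (tf.contrib and tensorflow.examples.tutorials are not in TF 2)
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('mnist_data/', one_hot=True)

# Note that one_hot encoding is used here, so a batch of labels has shape (batch_size, num_classes)
# and a float dtype. Without one_hot, a batch of labels has shape (batch_size,) and an integer dtype.

num_classes = 10
batch_size = 64
hidden_dim1 = 32  # units in the first LSTM layer
hidden_dim2 = 64  # units in the second LSTM layer
epochs = 10

embedding_dim = 28  # features per time step: the 28 pixels of one image row
time_step = 28      # time steps per sequence: the 28 rows of one image
'''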

Here is an analysis of tf.nn.dynamic_rnn(cell, inputs, initial_state=None, dtype=None, time_major=False).

It returns two values, outputs and states. outputs has shape [batch_size, time_step, cell.output_size].

The first dimension of this three-dimensional tensor is batch_size. Looking at the second and third dimensions, each example in the batch corresponds to a two-dimensional matrix of shape [time_step, output_size]. Each row of that matrix has the length of the RNN output dimension and represents the features of that example at the corresponding time step.

Transposing outputs to [time_step, batch_size, output_size] and taking the last element along the first dimension (index -1) gives, for every example in the batch, the output at the last time step: a tensor of shape [batch_size, output_size].

In a language model the time steps span a whole sentence, so this [batch_size, output_size] tensor is usually used as the input to the next layer, because each of its rows holds the features obtained after reading the entire sentence, which is very meaningful. If instead you take the first slice (index 0) of the transposed tensor [time_step, batch_size, output_size], the result is only the output after reading the first word of the sentence; using that as the input to the next layer would make the model perform terribly.

Now consider states. If the cell is an LSTM cell, states is a tuple with two elements, because an LSTM cell has two states: the cell state, which is the memory carried by the cell, and the hidden state, which is the cell's output at that step.

Both are the states at the last time step, so the shape of states does not depend on time_step.

Both state tensors have shape [batch_size, output_size] (for a basic LSTM cell, output_size equals num_units). states[0] is the cell state and states[1] is the hidden state.

Therefore states[1] equals the [batch_size, output_size] tensor obtained above by taking the last time step of outputs, because both hold the features of every example in the batch at the last time step. (This holds as long as no sequence_length argument is passed, so every sequence runs all time steps.)
'''


class RNN_model:
    def __init__(self):
        # Start from a clean graph so the model can be rebuilt in the same process
        tf.reset_default_graph()

    def add_placeholder(self):
        # Each MNIST image arrives flattened to 784 = 28 * 28 pixels
        self.xs = tf.placeholder(shape=[None, 784], dtype=tf.float32)
        # Labels are one-hot encoded, so the shape is (batch_size, 10) and the dtype is float.
        # Without one-hot labels this would be: self.ys = tf.placeholder(shape=[None], dtype=tf.int32)
        self.ys = tf.placeholder(shape=[None, 10], dtype=tf.float32)

    def rnn_layer(self):
        # Treat each of the 28 image rows as one time step with 28 features
        rnn_input = tf.reshape(tensor=self.xs, shape=[-1, time_step, embedding_dim])
        cell_1 = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim1)
        cell_2 = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim2)
        cells = tf.contrib.rnn.MultiRNNCell([cell_1, cell_2])  # two stacked LSTM layers
        # zero_state fixes the batch dimension, so every fed batch must hold exactly batch_size examples
        initial_state = cells.zero_state(batch_size=batch_size, dtype=tf.float32)
        outputs, states = tf.nn.dynamic_rnn(cells, rnn_input,
                                            initial_state=initial_state,
                                            time_major=False)
        # outputs.shape == (batch_size, time_step, hidden_dim2)
        # states holds, for each example in the batch, the state at the last time step,
        # so its shape does not depend on time_step
        # states[0][0].shape == states[0][1].shape == (batch_size, hidden_dim1)
        # states[1][0].shape == states[1][1].shape == (batch_size, hidden_dim2)
        # states[1][1] == tf.transpose(outputs, [1, 0, 2])[-1]
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
        self.rnn_output = outputs[-1]  # (batch_size, hidden_dim2)

    def output_layer(self):
        weights = tf.Variable(tf.random_normal(shape=[hidden_dim2, num_classes], dtype=tf.float32))
        biases = tf.Variable(tf.random_normal(shape=[num_classes], dtype=tf.float32))
        self.predict = tf.matmul(self.rnn_output, weights) + biases  # (batch_size, num_classes)

    # Note the difference between softmax_cross_entropy_with_logits and
    # sparse_softmax_cross_entropy_with_logits:
    # - the former requires logits and labels with the same shape (batch_size, num_classes)
    #   and the same dtype (float32);
    # - the latter takes logits of shape (batch_size, num_classes) with dtype float32 and labels
    #   of shape (batch_size,) with an integer dtype, where each value is the class index of the
    #   corresponding row of logits.
    def loss_layer(self):
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.predict, labels=self.ys))
        self.train_op = tf.train.AdamOptimizer(0.01).minimize(self.loss)
        self.accuracy = tf.reduce_mean(
            tf.cast(tf.equal(tf.argmax(self.ys, 1), tf.argmax(self.predict, 1)), dtype=tf.float32))

    def build_graph(self):
        self.add_placeholder()
        self.rnn_layer()
        self.output_layer()

    def train(self):
        num_batches = mnist.train.num_examples // batch_size
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(epochs):
                epoch_loss = 0.0
                for i in range(num_batches):
                    batch_x, batch_y = mnist.train.next_batch(batch_size)
                    feed_dict = {self.xs: batch_x, self.ys: batch_y}
                    _, loss_value = sess.run([self.train_op, self.loss], feed_dict=feed_dict)
                    epoch_loss += loss_value.item()
                # Accuracy is measured on a single test batch of batch_size images (matching zero_state)
                test_xs, test_ys = mnist.test.next_batch(batch_size)
                assert test_xs.shape == (batch_size, 784) and test_ys.shape == (batch_size, 10)
                acc = sess.run(self.accuracy, feed_dict={self.xs: test_xs, self.ys: test_ys})
                print("After %d epoch(s), loss value is %f, and accuracy is %f"
                      % (epoch + 1, epoch_loss / num_batches, acc.item()))
            # Saver.save fails if the parent directory is missing, so create it first
            tf.gfile.MakeDirs("checkpoints")
            saver.save(sess, "checkpoints/rnn_mnist.ckpt")


if __name__ == "__main__":
    model = RNN_model()
    model.build_graph()
    model.loss_layer()
    model.train()
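The shape relationships described in the docstring and in the rnn_layer comments can be checked without TensorFlow. The following is a minimal NumPy sketch, assuming a plain LSTM with one fused weight matrix for its four gates and a zero initial state; it is not TensorFlow's BasicLSTMCell (for instance it has no forget_bias), and the helpers sigmoid, make_lstm_params, lstm_layer and stacked_lstm are hypothetical names. It feeds batch-major input (as with time_major=False) through two stacked layers, the way MultiRNNCell does, and checks that the transpose-then-[-1] slice, outputs[:, -1, :], and the top layer's hidden state coincide, while the first-step output ignores every image row after the first.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def make_lstm_params(input_dim, num_units, rng):
    # One fused weight matrix for the four gates (input, forget, candidate, output)
    w = rng.normal(scale=0.1, size=(input_dim + num_units, 4 * num_units))
    b = np.zeros(4 * num_units)
    return w, b


def lstm_layer(inputs, w, b, num_units):
    """Run one LSTM layer over batch-major inputs [batch, time, features].

    Returns outputs [batch, time, num_units] and the final state (c, h).
    """
    batch_size, time_steps, _ = inputs.shape
    c = np.zeros((batch_size, num_units))  # cell state
    h = np.zeros((batch_size, num_units))  # hidden state
    outputs = []
    for t in range(time_steps):
        z = np.concatenate([inputs[:, t, :], h], axis=1) @ w + b
        i, f, g, o = np.split(z, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)  # an LSTM's output at step t is its hidden state
    return np.stack(outputs, axis=1), (c, h)


def stacked_lstm(inputs, layer_params):
    # Like MultiRNNCell: layer k feeds its outputs to layer k + 1,
    # and the returned states hold one (c, h) pair per layer
    states = []
    x = inputs
    for w, b, num_units in layer_params:
        x, state = lstm_layer(x, w, b, num_units)
        states.append(state)
    return x, tuple(states)


rng = np.random.default_rng(0)
batch_size, time_step, embedding_dim = 4, 28, 28
hidden_dim1, hidden_dim2 = 32, 64

# Random stand-ins for flattened 784-pixel images; each 28-pixel row becomes one time step
xs = rng.random((batch_size, 784))
rnn_input = xs.reshape(-1, time_step, embedding_dim)

params = [
    (*make_lstm_params(embedding_dim, hidden_dim1, rng), hidden_dim1),
    (*make_lstm_params(hidden_dim1, hidden_dim2, rng), hidden_dim2),
]
outputs, states = stacked_lstm(rnn_input, params)

last_via_transpose = np.transpose(outputs, (1, 0, 2))[-1]
print(outputs.shape)       # (4, 28, 64)
print(states[0][0].shape)  # (4, 32)
print(states[1][1].shape)  # (4, 64)
print(np.allclose(last_via_transpose, outputs[:, -1, :]))  # True
print(np.allclose(last_via_transpose, states[1][1]))       # True

# Changing every row except the first leaves the first-step output untouched,
# so it carries no information about the rest of the sequence
perturbed = rnn_input.copy()
perturbed[:, 1:, :] += 1.0
outputs_p, _ = stacked_lstm(perturbed, params)
print(np.allclose(outputs[:, 0, :], outputs_p[:, 0, :]))    # True
print(np.allclose(outputs[:, -1, :], outputs_p[:, -1, :]))  # False
```

Because outputs stays batch-major here, outputs[:, -1, :] selects the same slice as the transpose-then-[-1] idiom in rnn_layer; a TF1 tensor accepts the same slicing syntax.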

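The comment above loss_layer contrasts the dense and the sparse cross-entropy ops. As a sketch of why they agree, the NumPy code below reimplements the math rather than calling TensorFlow (softmax_cross_entropy, sparse_softmax_cross_entropy and log_softmax are hypothetical helper names): with one-hot labels and their integer class indices, both give the same per-example loss, and np.argmax turns one-hot labels back into the integer form the sparse variant expects.

```python
import numpy as np


def log_softmax(logits):
    # Subtract the row max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def softmax_cross_entropy(logits, onehot_labels):
    # Dense labels: same shape (batch_size, num_classes) and float dtype as logits
    return -(onehot_labels * log_softmax(logits)).sum(axis=1)


def sparse_softmax_cross_entropy(logits, int_labels):
    # Sparse labels: shape (batch_size,), one integer class index per row of logits
    return -log_softmax(logits)[np.arange(len(int_labels)), int_labels]


num_classes = 10
rng = np.random.default_rng(0)
logits = rng.normal(size=(5, num_classes)).astype(np.float32)
int_labels = np.array([3, 0, 9, 1, 7], dtype=np.int32)             # shape (5,)
onehot_labels = np.eye(num_classes, dtype=np.float32)[int_labels]  # shape (5, 10)

dense = softmax_cross_entropy(logits, onehot_labels)
sparse = sparse_softmax_cross_entropy(logits, int_labels)
print(np.allclose(dense, sparse))  # True

# Recovering integer labels from one-hot labels, as the sparse variant needs
print(np.argmax(onehot_labels, axis=1))  # [3 0 9 1 7]
```

In the TF1 model above, switching to sparse labels would mean reading the data with one_hot=False, using the int32 placeholder from the add_placeholder comment, and calling tf.nn.sparse_softmax_cross_entropy_with_logits(labels=..., logits=...); the accuracy line would then compare tf.argmax(self.predict, 1) with the integer labels cast to the same dtype.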

©2019-2020 Toolsou All rights reserved.
