TensorFlow BiLSTM official example
Published: 2019-06-25


'''
A Bidirectional Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

'''
To classify images using a bidirectional recurrent neural network, we consider
every image row as a sequence of pixels. Because MNIST image shape is 28*28px,
we will then handle 28 sequences of 28 steps for every sample.
'''

# Parameters
learning_rate = 0.001

# Roughly the total number of training samples consumed over the whole run
training_iters = 100000

# Number of samples per training batch
batch_size = 128

# How often (in steps) to print progress
display_step = 10

# Network Parameters
# n_steps * n_input is exactly one image: each row is fed to one time step
n_input = 28  # MNIST data input (img shape: 28*28)
n_steps = 28  # timesteps

# Hidden layer size
n_hidden = 128  # hidden layer num of features
n_classes = 10  # MNIST total classes (0-9 digits)

# tf Graph input
# In [None, n_steps, n_input], None leaves the batch dimension unspecified
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Define weights
weights = {
    # Hidden layer weights => 2*n_hidden because of forward + backward cells
    'out': tf.Variable(tf.random_normal([2*n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}


def BiRNN(x, weights, biases):

    # Prepare data shape to match `bidirectional_rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.unstack(x, n_steps, 1)

    # Define lstm cells with tensorflow
    # Forward direction cell
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Get lstm cell output
    try:
        outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                                     dtype=tf.float32)
    except Exception:  # Old TensorFlow version only returns outputs not states
        outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                               dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(outputs[-1], weights['out']) + biases['out']


pred = BiRNN(x, weights, biases)

# Define loss and optimizer
# softmax_cross_entropy_with_logits measures the probability error in discrete
# classification tasks where the classes are mutually exclusive; it returns a
# 1-D Tensor of length batch_size holding the per-example cross-entropy loss.
# reduce_mean (with no axis specified) then averages over all of those values.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Calculate batch accuracy
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # Calculate batch loss
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
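For reference, the same forward pass can also be written with tf.nn.bidirectional_dynamic_rnn, which consumes the (batch, n_steps, n_input) tensor directly, so no tf.unstack is needed. This is a minimal sketch, not part of the official example, assuming the same TF 1.x API as above:

# Minimal sketch (not from the official example): the same classifier using
# tf.nn.bidirectional_dynamic_rnn (TF 1.x), which takes the batch-major
# (batch, n_steps, n_input) tensor directly.
import tensorflow as tf
from tensorflow.contrib import rnn

n_input, n_steps, n_hidden, n_classes = 28, 28, 128, 10

x = tf.placeholder("float", [None, n_steps, n_input])
w = tf.Variable(tf.random_normal([2 * n_hidden, n_classes]))
b = tf.Variable(tf.random_normal([n_classes]))

lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

# outputs is a pair (output_fw, output_bw), each (batch, n_steps, n_hidden)
(output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(
    lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)

# The backward outputs are re-reversed into input order, so the backward
# cell's final output sits at time index 0.
last = tf.concat([output_fw[:, -1, :], output_bw[:, 0, :]], axis=1)
pred = tf.matmul(last, w) + b  # (batch, n_classes), same role as BiRNN above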

The official BiLSTM example is already written quite clearly. Since this was my first time reading it, I still had to look up quite a few things, especially the data handling.
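For instance, the tf.unstack step can be checked in isolation. A small sketch (the placeholder and variable names here are just for illustration):

# Toy check of the unstack step: a (batch, 28, 28) input becomes a Python
# list of 28 tensors, one (batch, 28) row-slice per time step.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28])
steps = tf.unstack(x, 28, 1)

print(len(steps))            # 28 time steps
print(steps[0].get_shape())  # (?, 28): row 0 of every image in the batch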

Data handling

Concatenation:

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
tf.stack([t1, t2], 0)  ==> [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
tf.stack([t1, t2], 1)  ==> [[[1, 2, 3], [7, 8, 9]], [[4, 5, 6], [10, 11, 12]]]
tf.stack([t1, t2], 2)  ==> [[[1, 7], [2, 8], [3, 9]], [[4, 10], [5, 11], [6, 12]]]

From the perspective of shape:

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0)  # [2,3] + [2,3] ==> [4, 3]
tf.concat([t1, t2], 1)  # [2,3] + [2,3] ==> [2, 6]
tf.stack([t1, t2], 0)   # [2,3] + [2,3] ==> [2*,2,3]
tf.stack([t1, t2], 1)   # [2,3] + [2,3] ==> [2,2*,3]
tf.stack([t1, t2], 2)   # [2,3] + [2,3] ==> [2,3,2*]
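The starred dimension is the new axis that tf.stack inserts; tf.concat never adds one. Equivalently, stacking along an axis is the same as concatenating after expand_dims. A quick sanity sketch:

# Sanity sketch: tf.stack along an axis equals tf.concat of expand_dims
import tensorflow as tf

t1 = tf.constant([[1, 2, 3], [4, 5, 6]])
t2 = tf.constant([[7, 8, 9], [10, 11, 12]])

stacked = tf.stack([t1, t2], 1)                                    # [2, 2*, 3]
concatd = tf.concat([tf.expand_dims(t1, 1), tf.expand_dims(t2, 1)], 1)

with tf.Session() as sess:
    print(sess.run(tf.reduce_all(tf.equal(stacked, concatd))))     # True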

Extraction:

input = [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
                                            [4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
                                           [[5, 5, 5]]]

tf.gather(input, [0, 2]) ==> [[[1, 1, 1], [2, 2, 2]],
                              [[5, 5, 5], [6, 6, 6]]]
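tf.slice(input, begin, size) takes size[i] elements starting at offset begin[i] on each axis, so it corresponds to Python-style slicing input[begin[i] : begin[i] + size[i]]. A quick check (the names here are just for illustration):

# Quick check: tf.slice(begin, size) vs. Python-style slicing on each axis
import tensorflow as tf

inp = tf.constant([[[1, 1, 1], [2, 2, 2]],
                   [[3, 3, 3], [4, 4, 4]],
                   [[5, 5, 5], [6, 6, 6]]])

sliced = tf.slice(inp, [1, 0, 0], [1, 2, 3])
indexed = inp[1:2, 0:2, 0:3]   # begin[i] : begin[i] + size[i] on each axis

with tf.Session() as sess:
    print(sess.run(tf.reduce_all(tf.equal(sliced, indexed))))  # True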

 

Reposted from: https://www.cnblogs.com/linyx/p/6979119.html
