TensorFlow 1.x Deep Learning Cookbook

How to do it…

We proceed with the recipe as follows:

  1. Import the required modules and declare the global variables:
import tensorflow as tf

# Global parameters
DATA_FILE = 'boston_housing.csv'
BATCH_SIZE = 10
NUM_FEATURES = 14  # total columns in the CSV: 13 features + the label
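Here, boston_housing.csv is assumed to hold the Boston Housing data as 14 numeric columns per row: the 13 features followed by MEDV, the median house value, which will serve as the label.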
  2. Next, we define a function that takes the filename as an argument and returns tensors in batches of size BATCH_SIZE:
def data_generator(filename):
    """
    Generates tensors in batches of size BATCH_SIZE.
    Args:
        filename: String Tensor
            Filename from which data is to be read
    Returns:
        Tensors feature_batch and label_batch
    """
  3. Inside this function, define the filename queue, f_queue, and the reader:
    f_queue = tf.train.string_input_producer(filename)
    reader = tf.TextLineReader(skip_header_lines=1)  # skips the header row
    _, value = reader.read(f_queue)
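Note that tf.train.string_input_producer expects a list of filenames rather than a single string, which is why the main block in step 7 passes [DATA_FILE].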
  4. Specify the default values to use if a field is missing, decode the .csv record, and select the features we need. For this example, we choose RM, PTRATIO, and LSTAT (a standalone snippet after step 5 shows record_defaults in action):
    record_defaults = [[0.0] for _ in range(NUM_FEATURES)]
    data = tf.decode_csv(value, record_defaults=record_defaults)
    # columns 5, 10, and 12 correspond to RM, PTRATIO, and LSTAT
    features = tf.stack([data[5], data[10], data[12]])
    label = data[-1]
  5. Define the parameters for batch generation and use tf.train.shuffle_batch() to randomly shuffle the tensors. The function returns the tensors feature_batch and label_batch:
    # minimum number of elements in the queue after a dequeue
    min_after_dequeue = 10 * BATCH_SIZE

    # the maximum number of elements in the queue
    capacity = 20 * BATCH_SIZE

    # shuffle the data to generate BATCH_SIZE sample pairs
    feature_batch, label_batch = tf.train.shuffle_batch(
        [features, label], batch_size=BATCH_SIZE,
        capacity=capacity, min_after_dequeue=min_after_dequeue)

    return feature_batch, label_batch
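To see what record_defaults from step 4 does, here is a small standalone check; the sample row is made up for illustration and is not taken from the dataset. The empty third field decodes to its default value:

row = tf.constant("0.02,18.0,,0.0,0.45,6.5,65.0,4.1,1.0,296.0,15.3,396.9,4.98,24.0")
decoded = tf.decode_csv(row, record_defaults=[[0.0]] * NUM_FEATURES)
with tf.Session() as sess:
    print(sess.run(decoded[2]))  # prints 0.0, the default for the empty field

Regarding the two queue parameters in step 5: min_after_dequeue is the number of elements kept in the queue after each dequeue, which determines how well the samples are mixed, while capacity must be strictly larger and bounds how far ahead the prefetching threads can fill the queue.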
  6. We define another function that generates the batches inside a session:
def generate_data(feature_batch, label_batch):
    with tf.Session() as sess:
        # start the queue threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(5):  # Generate 5 batches
            features, labels = sess.run([feature_batch, label_batch])
            print(features, labels)
        coord.request_stop()
        coord.join(threads)
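Note that tf.train.start_queue_runners is what actually launches the background threads that fill the queues created by tf.train.string_input_producer and tf.train.shuffle_batch; without it, the sess.run call would block indefinitely. The Coordinator is then used to stop those threads cleanly.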
  7. Now, we can use these two functions to get the data in batches. Here, we are just printing the data; during training, we would perform the optimization step at this point (a sketch follows the code below):
if __name__ == '__main__':
    feature_batch, label_batch = data_generator([DATA_FILE])
    generate_data(feature_batch, label_batch)
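As a sketch of what the optimization step mentioned in step 7 could look like, here is a hypothetical linear model wired to the generated batches; the model, learning rate, and loss are illustrative assumptions, not part of the recipe:

w = tf.Variable(tf.random_normal([3, 1]))  # hypothetical weights for the 3 selected features
b = tf.Variable(0.0)                       # hypothetical bias
y_pred = tf.squeeze(tf.matmul(feature_batch, w)) + b
loss = tf.reduce_mean(tf.square(label_batch - y_pred))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

Inside generate_data, one would then run tf.global_variables_initializer() before starting the queue runners and replace the sess.run([feature_batch, label_batch]) call with sess.run([train_op, loss]).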