文章修改至 https://tensorflow.google.cn/tutorials/load_data/csv
设置和获得数据
import需要的库:
- functools:适合于高阶函数,作用于或返回其他函数的函数,一般来说,对于该模块,任何可调用对象都可以视为一个函数
- numpy:矩阵处理
- tensorflow_datasets:
- pandas:数据分析
import functools
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import pandas as pd
获得数据
# 指定数据的所在地
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
TEST_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/eval.csv"
# 从相应的所在地下载数据,并获得下载数据所在的路径
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
test_file_path = tf.keras.utils.get_file("eval.csv", TEST_DATA_URL)
# numpy输出时,只保留到元素的小数点后3位,并且当数值小于当前精度时,则丢弃小数点后超过精度的值
np.set_printoptions(precision=3, suppress=True)
加载数据
打印最后5行CSV文件信息,了解文件的格式,其中“survived”为标签列
raw_dataset = pd.read_csv(train_file_path,na_values="?")
dataset = raw_dataset.copy()
dataset.tail()
正如你看到的那样,CSV 文件的每列都会有一个列名。dataset 的构造函数会自动识别这些列名。如果你使用的文件的第一行不包含列名,那么需要将列名通过字符串列表传给 make_csv_dataset 函数的 column_names 参数。
# 对于包含模型需要预测的值的列是你需要显式指定的。
LABEL_COLUMN = 'survived'
LABELS = [0, 1]
def