AI安全之对抗样本入门
上QQ阅读APP看书,第一时间看更新

1.4.1 测试数据

我们以Scikit-Learn环境介绍常见的性能衡量指标。为了便于演示,我们创建测试数据,测试数据一共有1000条记录,每条记录了100个特征,内容随机生成:

x, y = datasets.make_classification(n_samples=1000, n_features=100,
                    n_redundant=0, random_state = 1)

把数据集随机划分成训练集和测试集,其中测试集占40%:

train_X, test_X, train_y, test_y = train_test_split(x,
                                         y,
                                         test_size=0.2,
                                         random_state=66)

使用KNN算法进行训练和预测:

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_X, train_Y)
pred_Y = knn.predict(test_X)