import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

from timeit import default_timer as timer

# Load the MNIST dataset from a local file.
def load_mnist(path):
    """Load the MNIST train/test split from a local ``.npz`` archive.

    Args:
        path: Filesystem path to an ``.npz`` file containing the arrays
            ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` as NumPy arrays.

    Raises:
        KeyError: If one of the four expected arrays is missing from the
            archive.
    """
    # Using the NpzFile as a context manager closes the archive even when a
    # key lookup raises; the previous explicit close() was skipped on error.
    # Indexing an NpzFile reads the array fully into memory, so the returned
    # arrays stay valid after the archive is closed.
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)

# ---- Run-time measurement ----
# Start the wall-clock timer; it covers data loading, training and evaluation.
tic = timer()

# ---- Data preparation ----
# Model / data parameters.
num_classes = 10
input_shape = (28, 28, 1)
# Local copy of the MNIST archive (absolute POSIX path; adjust per machine).
path = r'/home/techuser/Downloads/mnist.npz'
(x_train, y_train), (x_test, y_test) = load_mnist(path)

# Cast to float32, scale the images to the [0, 1] range, and append a trailing
# axis so each image has shape (28, 28, 1) to match `input_shape`.
x_train, x_test = (
    np.expand_dims(images.astype("float32") / 255, -1)
    for images in (x_train, x_test)
)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# Turn the class-label vectors into binary (one-hot) class matrices, the
# target format expected by the categorical_crossentropy loss used below.
y_train, y_test = (
    keras.utils.to_categorical(y_train, num_classes),
    keras.utils.to_categorical(y_test, num_classes),
)

# ---- Model definition ----
def _conv_block(filters):
    """Return a 3x3 ReLU convolution followed by 2x2 max pooling."""
    return [
        layers.Conv2D(filters, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
    ]


# Two conv/pool stages (32 then 64 filters), then flatten, dropout and a
# softmax classifier with one output per class.
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        *_conv_block(32),
        *_conv_block(64),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

# ---- Training ----
batch_size = 128
epochs = 10
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
# Hold out the last 10% of the training data for validation during fitting.
model.fit(
    x_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_split=0.1,
)

# ---- Evaluation on the held-out test set ----
# With the single "accuracy" metric compiled above, score[0] is the loss and
# score[1] is the accuracy.
score = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = score[0], score[1]
print("Test loss:", test_loss)
print("Test accuracy:", test_accuracy)

# Stop the wall-clock timer started before data loading and report the
# elapsed run time (in seconds).
toc = timer()
elapsed = toc - tic
print("运行的时间是：%ss" % elapsed)

# Save the test metrics and run time to `filepath`. The `with` block
# guarantees the file is flushed and closed even if a write raises.
filepath = r'/home/techuser/result2.txt'
with open(filepath, 'w', encoding='utf-8') as f:
    f.write("Test loss: %s \n" % score[0])
    f.write("Test accuracy: %s \n" % score[1])
    f.write("运行的时间是：%ss" % elapsed)
