MXNet 神经网络训练和预测

参考教程（MXNet Module API）：https://mxnet.incubator.apache.org/tutorials/basic/module.html

# Train, evaluate, checkpoint and resume a small MLP on the letter-recognition
# dataset using the MXNet Module API.
import logging
import random

import mxnet as mx
import numpy as np

logging.getLogger().setLevel(logging.INFO)
# Seed every RNG in play so runs are repeatable.
mx.random.seed(1234)
np.random.seed(1234)
random.seed(1234)

# ---------------------------------------------------------------------------
# Prepare data
# ---------------------------------------------------------------------------
fname = mx.test_utils.download(
    "https://s3.us-east-2.amazonaws.com/mxnet-public/letter_recognition/letter-recognition.data"
)
# Column 0 holds the class letter (non-numeric, so genfromtxt cannot use it as
# a feature); the remaining comma-separated columns are the features.
data = np.genfromtxt(fname=fname, delimiter=",")[:, 1:]
# Map the leading letter to an integer class id ("A" -> 0, "B" -> 1, ...).
# Blank lines are skipped so labels stay aligned with genfromtxt's rows, and
# the file is closed deterministically.
with open(fname, "r") as f:
    label = np.array(
        [ord(line.split(",")[0]) - ord("A") for line in f if line.strip()]
    )

batch_size = 32
ntrain = int(data.shape[0] * 0.8)  # first 80% train, remaining 20% validation
train_iter = mx.io.NDArrayIter(
    data[:ntrain, :], label[:ntrain], batch_size, shuffle=True
)
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)

# ---------------------------------------------------------------------------
# Define the network: FC(64) -> ReLU -> FC(26) -> softmax (26 letter classes)
# ---------------------------------------------------------------------------
net = mx.sym.Variable("data")
net = mx.sym.FullyConnected(net, name="fc1", num_hidden=64)
net = mx.sym.Activation(net, name="relu1", act_type="relu")
net = mx.sym.FullyConnected(net, name="fc2", num_hidden=26)
net = mx.sym.SoftmaxOutput(net, name="softmax")
# Returns a graph visualisation (displayed inline in a notebook; the return
# value is discarded when run as a script). Needs the `graphviz` package.
mx.viz.plot_network(net, node_attrs={"shape": "oval", "fixedsize": "false"})

# ---------------------------------------------------------------------------
# Create the module
# ---------------------------------------------------------------------------
mod = mx.mod.Module(
    symbol=net,
    context=mx.cpu(),
    data_names=["data"],
    label_names=["softmax_label"],
)

# Intermediate-level API, equivalent to fit() below (kept for reference):
# mod.bind(data_shapes=train_iter.provide_data,
#          label_shapes=train_iter.provide_label)
# mod.init_params(initializer=mx.init.Uniform(scale=.1))
# mod.init_optimizer(optimizer="sgd",
#                    optimizer_params=(("learning_rate", 0.1),))
# metric = mx.metric.create("acc")
# for epoch in range(100):
#     train_iter.reset()
#     metric.reset()
#     for batch in train_iter:
#         mod.forward(batch, is_train=True)
#         mod.update_metric(metric, batch.label)
#         mod.backward()
#         mod.update()
#     print("Epoch %d, Training %s" % (epoch, metric.get()))

# ---------------------------------------------------------------------------
# High-level API: fit() runs the whole training loop
# ---------------------------------------------------------------------------
train_iter.reset()
mod = mx.mod.Module(
    symbol=net,
    context=mx.cpu(),
    data_names=["data"],
    label_names=["softmax_label"],
)
mod.fit(
    train_iter,
    eval_data=val_iter,
    optimizer="sgd",
    optimizer_params={"learning_rate": 0.1},
    eval_metric="acc",
    num_epoch=10,
)

# ---------------------------------------------------------------------------
# Predict and evaluate
# ---------------------------------------------------------------------------
y = mod.predict(val_iter)
# One row of 26 class scores per validation sample (derived from the split
# instead of a hard-coded row count).
assert y.shape == (data.shape[0] - ntrain, 26)

score = mod.score(val_iter, ["acc"])
print("Accuracy score is %f" % (score[0][1]))
assert score[0][1] > 0.76, "Achieved accuracy (%f) is less than expected (0.76)" % score[0][1]

# ---------------------------------------------------------------------------
# Save and load
# ---------------------------------------------------------------------------
# Callback that writes a checkpoint under `model_prefix` after each epoch.
model_prefix = "mx_mlp"
checkpoint = mx.callback.do_checkpoint(model_prefix)
mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()
# Assign the loaded parameters to the module.
mod.set_params(arg_params, aux_params)

# Resume training from the epoch-3 checkpoint.
mod = mx.mod.Module(symbol=sym)
mod.fit(
    train_iter,
    num_epoch=21,
    arg_params=arg_params,
    aux_params=aux_params,
    begin_epoch=3,
)
# Re-score the resumed model; the earlier `score` belongs to a different model.
score = mod.score(val_iter, ["acc"])
assert score[0][1] > 0.77, "Achieved accuracy (%f) is less than expected (0.77)" % score[0][1]

 

相关文章