Encode and decode

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
import torchvision
from mpl_toolkits.mplot3d import Axes3D   # for the 3D plot of the encoded features
from matplotlib import cm

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005              # learning rate
DOWNLOAD_MNIST = False
N_TEST_IMG = 5

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)


class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        # encoder: 784 -> 128 -> 64 -> 12
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            # nn.Tanh(),
            # nn.Linear(12, 3),
        )
        # decoder: 12 -> 64 -> 128 -> 784
        self.decoder = nn.Sequential(
            # nn.Linear(3, 12),
            # nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid()        # squash the output to (0, 1) to match the normalized pixels
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


autoencoder = AutoEncoder()     # instance; avoid reusing the class name "AutoEncoder"
# print(autoencoder)

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# top row of the figure: the first N_TEST_IMG original images
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   # continuously plot
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step, (x, b_label) in enumerate(train_loader):
        b_x = x.view(-1, 28 * 28)   # batch x, shape (batch, 28*28)
        b_y = x.view(-1, 28 * 28)   # the target is the input itself

        encoded, decoded = autoencoder(b_x)
        loss = loss_func(decoded, b_y)      # reconstruction loss (MSE)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print('Epoch:', epoch, '| train loss: %.4f' % loss.data.numpy())

            # bottom row of the figure: the current reconstructions of the test images
            _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# visualize the first 3 dimensions of the encoded features of 200 images in 3D
view_data = train_data.train_data[:200].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2)
ax = Axes3D(fig)    # newer matplotlib: fig.add_subplot(projection='3d')
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255 * s / 9))    # color by digit label
    ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()

Five images are picked from the training set for testing.

The result is displayed as a 2×5 grid of images: the top row shows the original images, and the bottom row shows the same images after being encoded and then decoded.
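If you only want to reproduce this comparison once training is finished, the minimal sketch below does just that. It assumes the trained model is available as autoencoder and the MNIST dataset as train_data, i.e. the names defined in the listing above; it is simply a rearrangement of the plotting code, not a separate API.

import torch
import matplotlib.pyplot as plt

N_TEST_IMG = 5

# first five training images, flattened to (5, 784) and scaled to [0, 1]
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28 * 28).type(torch.FloatTensor) / 255.

# run them through the trained autoencoder without tracking gradients
with torch.no_grad():
    _, decoded_data = autoencoder(view_data)

# two rows, five columns: top row = originals, bottom row = reconstructions
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
for i in range(N_TEST_IMG):
    a[0][i].imshow(view_data.numpy()[i].reshape(28, 28), cmap='gray')
    a[1][i].imshow(decoded_data.numpy()[i].reshape(28, 28), cmap='gray')
    for row in (0, 1):
        a[row][i].set_xticks(())
        a[row][i].set_yticks(())
plt.show()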
