Import the required packages
(omitted)
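The original omits this cell; below is a minimal sketch of the imports the rest of the notebook relies on (module aliases such as T and plt are inferred from how they are used later):
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import paddle
from paddle.io import Dataset, DataLoader
import paddle.vision.transforms as T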
Unzip the attached dataset
!unzip -qo data/data115112/ChineseStyle.zip -d data
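The archive unpacks to data/ChineseStyle/, with train and test splits that each contain a lishu and an xingkai subfolder; the dataset class defined below expects exactly this layout.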
# Quick check: one-hot encode a pair of class labels
label = paddle.to_tensor([1, 0], dtype='int64')
one_hot_label = paddle.nn.functional.one_hot(label, num_classes=2)
print(one_hot_label)
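For label = [1, 0] with num_classes=2 this prints a [2, 2] one-hot tensor: [[0., 1.], [1., 0.]].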
Build the dataset
# Define the ChineseStyleDataset dataset class
class ChineseStyleDataset(Dataset):
    # Build the data list and the label list
    def __init__(self, transforms=None, train="train"):
        super().__init__()
        self.transforms = transforms
        self.datas = list()   # list that holds the image data
        self.labels = list()  # list that holds the label data
        font_style = [("lishu", 0), ("xingkai", 1)]
        # Walk the lishu and xingkai folders under the train (or test) directory
        for font_tuple in font_style:
            font_name = font_tuple[0]
            font_val = font_tuple[1]
            font_path = "data/ChineseStyle/{}/{}".format(train, font_name)
            for filename in os.listdir(font_path):
                if ".ipynb_checkpoints" in filename:  # the image folders may contain an .ipynb_checkpoints entry, which has to be skipped
                    continue
                img_path = os.path.join(font_path, filename)
                # Read the image
                photo = Image.open(img_path)
                im = np.array(photo).astype("float32")
                if im is not None:  # guard against empty images; only non-empty images are added to the data list
                    self.datas.append(im)
                    self.labels.append(np.array(font_val, dtype="int64"))

    def __getitem__(self, index):
        data = self.datas[index]
        if self.transforms is not None:
            data = self.transforms(data)
        label = self.labels[index]
        return data, label

    def __len__(self):
        return len(self.labels)
# Preprocessing pipeline
transforms = T.Compose([
    T.Resize([227, 227]),
    T.Normalize(mean=[0, 0, 0], std=[255, 255, 255], data_format='HWC'),
    T.ToTensor()])

# Create the dataset instances
# Training dataset
train_dataset = ChineseStyleDataset(transforms=transforms)
# Test dataset
test_dataset = ChineseStyleDataset(transforms=transforms, train="test")
Data loaders
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=200)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=True)
# np.set_printoptions(threshold=np.inf)
# for id, data in enumerate(train_dataloader):
#     print(np.array(data[0][0]))
#     plt.imshow(np.array(data[0][0]))
#     break
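To sanity-check one batch, a minimal sketch (assuming matplotlib is imported as plt): the loader yields images in NCHW layout, so move the channel axis to the end before calling imshow.
# Grab one batch: images have shape [batch, 3, 227, 227], labels have shape [batch]
images, labels = next(iter(train_dataloader))
print(images.shape, labels.shape)
# imshow expects HWC, so transpose the first image's channels to the last axis
plt.imshow(np.array(images[0]).transpose(1, 2, 0))
plt.title("label: {}".format(int(labels[0])))
plt.show()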
Build the AlexNet network
class AlexNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # Five convolution blocks
        self.conv_pool1 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=3, out_channels=96, kernel_size=[11, 11], stride=4, padding="valid"),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool2D(kernel_size=[3, 3], stride=2))
        self.conv_pool2 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=96, out_channels=256, kernel_size=[5, 5], stride=1, padding="same"),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool2D(kernel_size=[3, 3], stride=2))
        self.conv_pool3 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=256, out_channels=384, kernel_size=[3, 3], stride=1, padding="same"),
            paddle.nn.ReLU())
        self.conv_pool4 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=384, out_channels=384, kernel_size=[3, 3], stride=1, padding="same"),
            paddle.nn.ReLU())
        self.conv_pool5 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=384, out_channels=256, kernel_size=[3, 3], stride=1, padding="same"),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool2D(kernel_size=[3, 3], stride=2))
        # Fully connected head with dropout; the final layer has 2 outputs for the two font classes
        self.full_con = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=256 * 6 * 6, out_features=4096),
            paddle.nn.ReLU(),
            paddle.nn.Dropout(0.5),
            paddle.nn.Linear(in_features=4096, out_features=4096),
            paddle.nn.ReLU(),
            paddle.nn.Dropout(0.5),
            paddle.nn.Linear(in_features=4096, out_features=2)
            # , paddle.nn.Softmax()
        )
        self.flatten = paddle.nn.Flatten()
        self.act = paddle.nn.Sigmoid()

    def forward(self, x):
        x = self.conv_pool1(x)
        x = self.conv_pool2(x)
        x = self.conv_pool3(x)
        x = self.conv_pool4(x)
        x = self.conv_pool5(x)
        x = self.flatten(x)
        x = self.full_con(x)
        x = self.act(x)
        return x
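For a 227x227 input, the feature map shrinks to 27x27 after conv_pool1, 13x13 after conv_pool2, and 6x6 after conv_pool5, so the flattened vector has 256 * 6 * 6 = 9216 elements, which is why the first Linear layer uses in_features=256*6*6.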
# Instantiate the network
alexNet = AlexNet()
paddle.summary(alexNet, (1, 3, 227, 227))
Output:
Network configuration, training, and saving
# Wrap the model instance in a high-level-API Model object
model = paddle.Model(alexNet)
# Network configuration
model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.001),
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())
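Note that paddle.nn.CrossEntropyLoss applies softmax internally and expects raw logits, while this model's forward already passes the 2-dimensional output through a Sigmoid; training still works for this two-class task, but feeding the pre-activation logits would be the more conventional pairing.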
vsDL = paddle.callbacks.VisualDL("log_dir")
model.fit(train_data=train_dataloader, epochs=10, verbose=1, callbacks=vsDL)
model.evaluate(eval_data=test_dataloader, verbose=1)
model.save("mymodel/AlexNet")
Output:
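Load the saved model and predict on a single image. model.save("mymodel/AlexNet") wrote the weights to mymodel/AlexNet.pdparams (plus optimizer state in AlexNet.pdopt); model.load below reads them back into a freshly built network.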
# Read a sample image and apply the same preprocessing used during training
im = np.array(Image.open("data/ChineseStyle/train/xingkai/xingkai_1001.jpg")).astype("float32")
im = transforms(im)                # resize to 227x227, scale to [0, 1], HWC -> CHW
im = paddle.unsqueeze(im, axis=0)  # add the batch dimension: [1, 3, 227, 227]
# Rebuild the network, load the trained weights, and run a single-batch prediction
alexNet = AlexNet()
model = paddle.Model(alexNet)
model.load("mymodel/AlexNet")
result = model.predict_batch(im)
print(result)
Output:
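predict_batch returns a list with one array of shape [1, 2] holding the two Sigmoid scores. A short follow-up sketch (assuming the label mapping 0 = lishu, 1 = xingkai used when building the dataset) turns it into a class prediction:
pred = int(np.argmax(result[0], axis=1)[0])
print("predicted class:", pred)  # 0 = lishu, 1 = xingkai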