1-1.结构化数据建模流程范例
import os
#mac系统上pytorch和matplotlib在jupyter中同时跑需要更改环境变量
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
!pip install torch==2.0.0
!pip install -U torchkeras
import torch
import torchkeras
print("torch.__version__ = ", torch.__version__)
print("torchkeras.__version__ = ", torchkeras.__version__)
torch.__version__ = 2.0.1
torchkeras.__version__ = 3.9.3
一,准备数据
titanic数据集的目标是根据乘客信息预测他们在Titanic号撞击冰山沉没后能否生存
结构化数据一般会使用Pandas中的DataFrame进行预处理。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader,TensorDataset
dftrain_raw = pd.read_csv('./eat_pytorch_datasets/titanic/train.csv')
dftest_raw = pd.read_csv('./eat_pytorch_datasets/titanic/test.csv')
dftrain_raw.head(10)
PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 493 | 0 | 1 | Molson, Mr. Harry Markland | male | 55.0 | 0 | 0 | 113787 | 30.5000 | C30 | S |
1 | 53 | 1 | 1 | Harper, Mrs. Henry Sleeper (Myna Haxtun) | female | 49.0 | 1 | 0 | PC 17572 | 76.7292 | D33 | C |
2 | 388 | 1 | 2 | Buss, Miss. Kate | female | 36.0 | 0 | 0 | 27849 | 13.0000 | NaN | S |
3 | 192 | 0 | 2 | Carbines, Mr. William | male | 19.0 | 0 | 0 | 28424 | 13.0000 | NaN | S |
4 | 687 | 0 | 3 | Panula, Mr. Jaako Arnold | male | 14.0 | 4 | 1 | 3101295 | 39.6875 | NaN | S |
5 | 16 | 1 | 2 | Hewlett, Mrs. (Mary D Kingcome) | female | 55.0 | 0 | 0 | 248706 | 16.0000 | NaN | S |
6 | 228 | 0 | 3 | Lovell, Mr. John Hall ("Henry") | male | 20.5 | 0 | 0 | A/5 21173 | 7.2500 | NaN | S |
7 | 884 | 0 | 2 | Banfield, Mr. Frederick James | male | 28.0 | 0 | 0 | C.A./SOTON 34068 | 10.5000 | NaN | S |
8 | 168 | 0 | 3 | Skoog, Mrs. William (Anna Bernhardina Karlsson) | female | 45.0 | 1 | 4 | 347088 | 27.9000 | NaN | S |
9 | 752 | 1 | 3 | Moor, Master. Meier | male | 6.0 | 0 | 1 | 392096 | 12.4750 | E121 | S |
字段说明:
- Survived:0代表死亡,1代表存活【y标签】
- Pclass:乘客所持票类,有三种值(1,2,3) 【转换成onehot编码】
- Name:乘客姓名 【舍去】
- Sex:乘客性别 【转换成bool特征】
- Age:乘客年龄(有缺失) 【数值特征,添加“年龄是否缺失”作为辅助特征】
- SibSp:乘客兄弟姐妹/配偶的个数(整数值) 【数值特征】
- Parch:乘客父母/孩子的个数(整数值)【数值特征】
- Ticket:票号(字符串)【舍去】
- Fare:乘客所持票的价格(浮点数,0-500不等) 【数值特征】
- Cabin:乘客所在船舱(有缺失) 【添加“所在船舱是否缺失”作为辅助特征】
- Embarked:乘客登船港口:S、C、Q(有缺失)【转换成onehot编码,四维度 S,C,Q,nan】
利用Pandas的数据可视化功能我们可以简单地进行探索性数据分析EDA(Exploratory Data Analysis)。
label分布情况
%matplotlib inline
%config InlineBackend.figure_format = 'png'
ax = dftrain_raw['Survived'].value_counts().plot(kind = 'bar',
figsize = (12,8),fontsize=15,rot = 0)
ax.set_ylabel('Counts',fontsize = 15)
ax.set_xlabel('Survived',fontsize = 15)
plt.show()
年龄分布情况
%matplotlib inline
%config InlineBackend.figure_format = 'png'
ax = dftrain_raw['Age'].plot(kind = 'hist',bins = 20,color= 'purple',
figsize = (12,8),fontsize=15)
ax.set_ylabel('Frequency',fontsize = 15)
ax.set_xlabel('Age',fontsize = 15)
plt.show()
年龄和label的相关性
%matplotlib inline
%config InlineBackend.figure_format = 'png'
ax = dftrain_raw.query('Survived == 0')['Age'].plot(kind = 'density',
figsize = (12,8),fontsize=15)
dftrain_raw.query('Survived == 1')['Age'].plot(kind = 'density',
figsize = (12,8),fontsize=15)
ax.legend(['Survived==0','Survived==1'],fontsize = 12)
ax.set_ylabel('Density',fontsize = 15)
ax.set_xlabel('Age',fontsize = 15)
plt.show()
下面为正式的数据预处理
def preprocessing(dfdata):
dfresult= pd.DataFrame()
#Pclass
dfPclass = pd.get_dummies(dfdata['Pclass'])
dfPclass.columns = ['Pclass_' +str(x) for x in dfPclass.columns ]
dfresult = pd.concat([dfresult,dfPclass],axis = 1)
#Sex
dfSex = pd.get_dummies(dfdata['Sex'])
dfresult = pd.concat([dfresult,dfSex],axis = 1)
#Age
dfresult['Age'] = dfdata['Age'].fillna(0)
dfresult['Age_null'] = pd.isna(dfdata['Age']).astype('int32')
#SibSp,Parch,Fare
dfresult['SibSp'] = dfdata['SibSp']
dfresult['Parch'] = dfdata['Parch']
dfresult['Fare'] = dfdata['Fare']
#Carbin
dfresult['Cabin_null'] = pd.isna(dfdata['Cabin']).astype('int32')
#Embarked
dfEmbarked = pd.get_dummies(dfdata['Embarked'],dummy_na=True)
dfEmbarked.columns = ['Embarked_' + str(x) for x in dfEmbarked.columns]
dfresult = pd.concat([dfresult,dfEmbarked],axis = 1)
return(dfresult)
x_train = preprocessing(dftrain_raw).values
y_train = dftrain_raw[['Survived']].values
x_test = preprocessing(dftest_raw).values
y_test = dftest_raw[['Survived']].values
print("x_train.shape =", x_train.shape )
print("x_test.shape =", x_test.shape )
print("y_train.shape =", y_train.shape )
print("y_test.shape =", y_test.shape )
x_train.shape = (712, 15)
x_test.shape = (179, 15)
y_train.shape = (712, 1)
y_test.shape = (179, 1)
进一步使用DataLoader和TensorDataset封装成可以迭代的数据管道。
dl_train = DataLoader(TensorDataset(torch.tensor(x_train).float(),torch.tensor(y_train).float()),
shuffle = True, batch_size = 8)
dl_val = DataLoader(TensorDataset(torch.tensor(x_test).float(),torch.tensor(y_test).float()),
shuffle = False, batch_size = 8)
# 测试数据管道
for features,labels in dl_train:
print(features,labels)
break
tensor([[ 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 21.0000, 0.0000, 1.0000,
0.0000, 9.8250, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000],
[ 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 1.0000,
0.0000, 89.1042, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
[ 1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 65.0000, 0.0000, 0.0000,
1.0000, 61.9792, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
[ 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 36.0000, 0.0000, 0.0000,
2.0000, 71.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
[ 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 32.5000, 0.0000, 0.0000,
0.0000, 13.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
[ 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 45.0000, 0.0000, 0.0000,
0.0000, 7.7500, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000],
[ 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 1.0000, 2.0000,
0.0000, 23.2500, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000],
[ 1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 80.0000, 0.0000, 0.0000,
0.0000, 30.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000]]) tensor([[0.],
[1.],
[0.],
[1.],
[1.],
[0.],
[1.],
[1.]])
二,定义模型
使用Pytorch通常有三种方式构建模型:使用nn.Sequential按层顺序构建模型,继承nn.Module基类构建自定义模型,继承nn.Module基类构建模型并辅助应用模型容器进行封装。
此处选择使用最简单的nn.Sequential,按层顺序模型。
def create_net():
net = nn.Sequential()
net.add_module("linear1",nn.Linear(15,20))
net.add_module("relu1",nn.ReLU())
net.add_module("linear2",nn.Linear(20,15))
net.add_module("relu2",nn.ReLU())
net.add_module("linear3",nn.Linear(15,1))
return net
net = create_net()
print(net)
Sequential(
(linear1): Linear(in_features=15, out_features=20, bias=True)
(relu1): ReLU()
(linear2): Linear(in_features=20, out_features=15, bias=True)
(relu2): ReLU()
(linear3): Linear(in_features=15, out_features=1, bias=True)
)
三,训练模型
Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。
有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环。
此处介绍一种较通用的仿照Keras风格的脚本形式的训练循环。
该脚本形式的训练代码与 torchkeras 库的核心代码基本一致。
torchkeras详情: https://github.com/lyhue1991/torchkeras
import os,sys,time
import numpy as np
import pandas as pd
import datetime
from tqdm import tqdm
import torch
from torch import nn
from copy import deepcopy
from torchkeras.metrics import Accuracy
def printlog(info):
nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print("\n"+"=========="*8 + "%s"%nowtime)
print(str(info)+"\n")
loss_fn = nn.BCEWithLogitsLoss()
optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)
metrics_dict = {"acc":Accuracy()}
epochs = 20
ckpt_path='checkpoint.pt'
#early_stopping相关设置
monitor="val_acc"
patience=5
mode="max"
history = {}
for epoch in range(1, epochs+1):
printlog("Epoch {0} / {1}".format(epoch, epochs))
# 1,train -------------------------------------------------
net.train()
total_loss,step = 0,0
loop = tqdm(enumerate(dl_train), total =len(dl_train),file = sys.stdout)
train_metrics_dict = deepcopy(metrics_dict)
for i, batch in loop:
features,labels = batch
#forward
preds = net(features)
loss = loss_fn(preds,labels)
#backward
loss.backward()
optimizer.step()
optimizer.zero_grad()
#metrics
step_metrics = {"train_"+name:metric_fn(preds, labels).item()
for name,metric_fn in train_metrics_dict.items()}
step_log = dict({"train_loss":loss.item()},**step_metrics)
total_loss += loss.item()
step+=1
if i!=len(dl_train)-1:
loop.set_postfix(**step_log)
else:
epoch_loss = total_loss/step
epoch_metrics = {"train_"+name:metric_fn.compute().item()
for name,metric_fn in train_metrics_dict.items()}
epoch_log = dict({"train_loss":epoch_loss},**epoch_metrics)
loop.set_postfix(**epoch_log)
for name,metric_fn in train_metrics_dict.items():
metric_fn.reset()
for name, metric in epoch_log.items():
history[name] = history.get(name, []) + [metric]
# 2,validate -------------------------------------------------
net.eval()
total_loss,step = 0,0
loop = tqdm(enumerate(dl_val), total =len(dl_val),file = sys.stdout)
val_metrics_dict = deepcopy(metrics_dict)
with torch.no_grad():
for i, batch in loop:
features,labels = batch
#forward
preds = net(features)
loss = loss_fn(preds,labels)
#metrics
step_metrics = {"val_"+name:metric_fn(preds, labels).item()
for name,metric_fn in val_metrics_dict.items()}
step_log = dict({"val_loss":loss.item()},**step_metrics)
total_loss += loss.item()
step+=1
if i!=len(dl_val)-1:
loop.set_postfix(**step_log)
else:
epoch_loss = (total_loss/step)
epoch_metrics = {"val_"+name:metric_fn.compute().item()
for name,metric_fn in val_metrics_dict.items()}
epoch_log = dict({"val_loss":epoch_loss},**epoch_metrics)
loop.set_postfix(**epoch_log)
for name,metric_fn in val_metrics_dict.items():
metric_fn.reset()
epoch_log["epoch"] = epoch
for name, metric in epoch_log.items():
history[name] = history.get(name, []) + [metric]
# 3,early-stopping -------------------------------------------------
arr_scores = history[monitor]
best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)
if best_score_idx==len(arr_scores)-1:
torch.save(net.state_dict(),ckpt_path)
print("<<<<<< reach best {0} : {1} >>>>>>".format(monitor,
arr_scores[best_score_idx]),file=sys.stderr)
if len(arr_scores)-best_score_idx>patience:
print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
monitor,patience),file=sys.stderr)
break
net.load_state_dict(torch.load(ckpt_path))
dfhistory = pd.DataFrame(history)
================================================================================2023-08-02 11:48:14
Epoch 1 / 20
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 661.90it/s, train_acc=0.654, train_loss=0.65]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1108.37it/s, val_acc=0.698, val_loss=0.584]
<<<<<< reach best val_acc : 0.6983240246772766 >>>>>>
================================================================================2023-08-02 11:48:14
Epoch 2 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 761.63it/s, train_acc=0.718, train_loss=0.574]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 918.43it/s, val_acc=0.749, val_loss=0.482]
<<<<<< reach best val_acc : 0.748603343963623 >>>>>>
================================================================================2023-08-02 11:48:14
Epoch 3 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 816.67it/s, train_acc=0.788, train_loss=0.513]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1031.02it/s, val_acc=0.765, val_loss=0.478]
<<<<<< reach best val_acc : 0.7653631567955017 >>>>>>
================================================================================2023-08-02 11:48:14
Epoch 4 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 783.66it/s, train_acc=0.795, train_loss=0.508]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1012.42it/s, val_acc=0.777, val_loss=0.416]
<<<<<< reach best val_acc : 0.7765362858772278 >>>>>>
================================================================================2023-08-02 11:48:14
Epoch 5 / 20
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 792.31it/s, train_acc=0.778, train_loss=0.5]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 849.80it/s, val_acc=0.793, val_loss=0.454]
<<<<<< reach best val_acc : 0.7932960987091064 >>>>>>
================================================================================2023-08-02 11:48:14
Epoch 6 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 816.62it/s, train_acc=0.792, train_loss=0.466]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1071.58it/s, val_acc=0.793, val_loss=0.48]
================================================================================2023-08-02 11:48:15
Epoch 7 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 799.33it/s, train_acc=0.791, train_loss=0.486]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1063.58it/s, val_acc=0.782, val_loss=0.441]
================================================================================2023-08-02 11:48:15
Epoch 8 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 724.34it/s, train_acc=0.789, train_loss=0.466]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1211.66it/s, val_acc=0.81, val_loss=0.433]
<<<<<< reach best val_acc : 0.8100558519363403 >>>>>>
================================================================================2023-08-02 11:48:15
Epoch 9 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 742.96it/s, train_acc=0.787, train_loss=0.473]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 891.92it/s, val_acc=0.782, val_loss=0.438]
================================================================================2023-08-02 11:48:15
Epoch 10 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 780.30it/s, train_acc=0.812, train_loss=0.463]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1006.84it/s, val_acc=0.782, val_loss=0.418]
================================================================================2023-08-02 11:48:15
Epoch 11 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 823.80it/s, train_acc=0.788, train_loss=0.466]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1130.61it/s, val_acc=0.782, val_loss=0.477]
================================================================================2023-08-02 11:48:15
Epoch 12 / 20
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 803.21it/s, train_acc=0.791, train_loss=0.463]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1183.49it/s, val_acc=0.777, val_loss=0.468]
================================================================================2023-08-02 11:48:15
Epoch 13 / 20
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 817.11it/s, train_acc=0.795, train_loss=0.46]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 1159.69it/s, val_acc=0.788, val_loss=0.469]
<<<<<< val_acc without improvement in 5 epoch, early stopping >>>>>>
四,评估模型
我们首先评估一下模型在训练集和验证集上的效果。
dfhistory
train_loss | train_acc | val_loss | val_acc | epoch | |
---|---|---|---|---|---|
0 | 0.653146 | 0.662921 | 0.589680 | 0.681564 | 1 |
1 | 0.595200 | 0.700843 | 0.523722 | 0.759777 | 2 |
2 | 0.531601 | 0.758427 | 0.493227 | 0.765363 | 3 |
3 | 0.540394 | 0.766854 | 0.493356 | 0.720670 | 4 |
4 | 0.511390 | 0.793539 | 0.512084 | 0.754190 | 5 |
5 | 0.512636 | 0.787921 | 0.465292 | 0.776536 | 6 |
6 | 0.482334 | 0.785112 | 0.457128 | 0.776536 | 7 |
7 | 0.494457 | 0.783708 | 0.468475 | 0.787709 | 8 |
8 | 0.511432 | 0.785112 | 0.441753 | 0.776536 | 9 |
9 | 0.496386 | 0.765449 | 0.462543 | 0.776536 | 10 |
10 | 0.480010 | 0.782303 | 0.435424 | 0.810056 | 11 |
11 | 0.468407 | 0.789326 | 0.408479 | 0.798883 | 12 |
12 | 0.465568 | 0.792135 | 0.403323 | 0.815642 | 13 |
13 | 0.472104 | 0.778090 | 0.476357 | 0.770950 | 14 |
14 | 0.473596 | 0.793539 | 0.447321 | 0.798883 | 15 |
15 | 0.444280 | 0.793539 | 0.405534 | 0.793296 | 16 |
16 | 0.460128 | 0.794944 | 0.428926 | 0.787709 | 17 |
17 | 0.440345 | 0.806180 | 0.435658 | 0.776536 | 18 |
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
def plot_metric(dfhistory, metric):
train_metrics = dfhistory["train_"+metric]
val_metrics = dfhistory['val_'+metric]
epochs = range(1, len(train_metrics) + 1)
plt.plot(epochs, train_metrics, 'bo--')
plt.plot(epochs, val_metrics, 'ro-')
plt.title('Training and validation '+ metric)
plt.xlabel("Epochs")
plt.ylabel(metric)
plt.legend(["train_"+metric, 'val_'+metric])
plt.show()
plot_metric(dfhistory,"loss")
plot_metric(dfhistory,"acc")
五,使用模型
#预测概率
y_pred_probs = torch.sigmoid(net(torch.tensor(x_test[0:10]).float())).data
y_pred_probs
tensor([[0.0487],
[0.5014],
[0.2651],
[0.9025],
[0.4703],
[0.9044],
[0.0710],
[0.9568],
[0.4578],
[0.1043]])
#预测类别
y_pred = torch.where(y_pred_probs>0.5,
torch.ones_like(y_pred_probs),torch.zeros_like(y_pred_probs))
y_pred
tensor([[0.],
[1.],
[0.],
[1.],
[0.],
[1.],
[0.],
[1.],
[0.],
[0.]])
六,保存模型
Pytorch 有两种保存模型的方式,都是通过调用pickle序列化方法实现的。
第一种方法只保存模型参数。
第二种方法保存完整模型。
推荐使用第一种,第二种方法可能在切换设备和目录的时候出现各种问题。
1,保存模型参数(推荐)
print(net.state_dict().keys())
odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'linear3.weight', 'linear3.bias'])
# 保存模型参数
torch.save(net.state_dict(), "./data/net_parameter.pt")
net_clone = create_net()
net_clone.load_state_dict(torch.load("./data/net_parameter.pt"))
torch.sigmoid(net_clone.forward(torch.tensor(x_test[0:10]).float())).data
tensor([[0.0487],
[0.5014],
[0.2651],
[0.9025],
[0.4703],
[0.9044],
[0.0710],
[0.9568],
[0.4578],
[0.1043]])
2,保存完整模型(不推荐)
torch.save(net, './data/net_model.pt')
net_loaded = torch.load('./data/net_model.pt')
torch.sigmoid(net_loaded(torch.tensor(x_test[0:10]).float())).data
tensor([[0.0487],
[0.5014],
[0.2651],
[0.9025],
[0.4703],
[0.9044],
[0.0710],
[0.9568],
[0.4578],
[0.1043]])