我的神经网络类python代码编程习惯

自定义数据集

自定义数据集至少要重写__init____len____getitem__三个方法,init中定义数据路径,最好能把数据读进内存;len中定义有多少个训练样本;getitem尽量只从内存读,避免读磁盘,若数据太大,可以维持一个固定大小的内存池,偶尔从磁盘读。

若getitem包含运算,则设置num_workers>0,并行读取
torch.backends.cudnn.benchmark = True 开启可以加速卷积神经网络运算。

Dataloader示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torch.utils.data import Dataloader
from tqdm import tqdm

dataloader = Dataloader(dataset, batchsize=8, shuffle=True)

for i in range(epoch):
with tqdm(total=len(dataloader)) as t:
for idx, (batch_x, batch_y) in enumerate(dataloader):
# pre_y = model(batch_x)
# loss= loss_fn(pre_y, batch_y)
t.set_description(desc="Epoch %i:"%i)
t.set_postfix(steps=idx, loss=loss.data.item())
t.update(1)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()

其他注意

  1. 代码文件中标注

    1
    __author__ = 'kly'
  2. 配置参数使用argparse,并在运行时打印配置信息,以便日志中保存:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import argparse
    def make_parser():
    parser = argparse.ArgumentParser("train parameter")
    parser.add_argument("-w", "--weight_path", default="./weights", type=str, help="model save path")
    parser.add_argument("--log_path", default="./logs", type=str, help="tensorboard log save path")
    parser.add_argument("--class_nums", default=45, type=int, help="how many classes do you have")
    parser.add_argument("--epoch_nums", default=5, type=int, help="train epoch")
    parser.add_argument("--batch_size", default=16, type=int, help="batch size")
    parser.add_argument("--lr", default=0.0001, type=float, help="init lr")
    parser.add_argument("--tsize", default=256, type=int, help="train img size = tsize * tsize")
    return parser

    # main函数中
    args = make_parser().parse_args()
    print("----------------------------------------------------------------")
    for key in args.__dict__:
    print(key, end=' = ')
    print(args.__dict__[key])
    print("----------------------------------------------------------------")
  1. 对于有后续完善空间的部分要标注#TODO

  2. 适当写警告和报错语句

    1
    2
    3
    4
    5
    6
    try:
    assert np.isfinite(score)
    except AssertionError as e:
    raise ValueError('score is NaN or infinite') from e

    raise ValueError('the function is not supported now')
  3. 固定所有随机数种子。

  4. 每次训练要将训练log输出到文件,保存在train.log中,并且log中需打印出本次实验的参数配置。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import sys
    import os

    class Logger(object):
    def __init__(self, filename='default.log', stream=sys.stdout):
    self.terminal = stream
    self.log = open(filename, 'w')

    def write(self, message):
    self.terminal.write(message)
    self.log.write(message)

    def flush(self):
    pass

    sys.stdout = Logger(os.path.join(args.log_path,time.strftime('train21-%m-%d-%H%M.log',time.localtime(time.time()))), sys.stdout)
    # sys.stderr = Logger('a.log_file', sys.stderr)
  5. 打印训练过程时,记得加上时间

    1
    2
    import time
    print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
  6. 使用Tensorboard进行可视化

    pytorch使用Tensorboard示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    from torch.utils import tensorboard

    writer = tensorboard.SummaryWriter('../logs/')
    print('tensorboard initialized')

    init_image = torch.zeros((1,3,224,224), device=device)
    writer.add_graph(model, init_image)

    writer.add_scalar('train_loss', loss, epoch+1)
    writer.add_scalar('val_acc', acc, epoch+1)
    writer.add_scalar('lr', lr, epoch+1)
  1. 一些需要反复使用的调试语句,可以使用logger输出

    给logger设置是告诉它要记录哪些级别的日志,给handler设是告诉它要输出哪些级别的日志,相当于进行了两次过滤。这样的好处在于,当我们有多个日志去向时,比如既保存到文件,又输出到控制台,就可以分别给他们设置不同的级别;logger 的级别是先过滤的,所以被 logger 过滤的日志 handler 也是无法记录的,这样就可以只改 logger 的级别而影响所有输出。两者结合可以更方便地管理日志记录的级别。

    logging.FileHandler -> 文件输出

    logging.StreamHandler() # 控制台输出

    logging.handlers.RotatingFileHandler -> 按照大小自动分割日志文件,一旦达到指定的大小重新生成文件

    logging.handlers.TimedRotatingFileHandler -> 按照时间自动分割日志文件

    logger.debug(‘debug级别,一般用来打印一些调试信息,级别最低’)

    logger.info(‘info级别,一般用来打印一些正常的操作信息’)

    logger.warning(‘waring级别,一般用来打印警告信息’)

    logger.error(‘error级别,一般用来打印一些错误信息’)

    logger.critical(‘critical级别,一般用来打印一些致命的错误信息,等级最高’)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import logging
from logging import handlers

logger = logging.getLogger('train')
logger.setLevel(level=logging.DEBUG) # 设置打印级别
formatter = logging.Formatter('%(asctime)s: %(message)s') # 设置打印格式

stream_handler = logging.StreamHandler() # 控制台输出
stream_handler.setLevel(logging.DEBUG)
stream_handler.setFormatter(formatter)

file_handler = logging.FileHandler('train1.log', encoding='utf-8')
file_handler.setLevel(level=logging.INFO)
file_handler.setFormatter(formatter)

logger.addHandler(stream_handler)
logger.addHandler(file_handler)

logger.info('info级别,一般用来打印一些正常的操作信息')

time_rotating_file_handler = handlers.TimedRotatingFileHandler(filename='rotating_test.log', when='D',encoding='utf-8')
time_rotating_file_handler.setLevel(logging.INFO)
time_rotating_file_handler.setFormatter(formatter)
logger.addHandler(time_rotating_file_handler)
  1. 代码需要注意包含:断点续训、保存模型、加载模型进行测试这几部分。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # pytorch

    # 加载模型
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location=device))

    # 保存最优模型
    if val_accurate > best_acc:
    best_acc = val_accurate
    torch.save(net.state_dict(), save_path)
  1. 项目开源时要有requirements.txt文件,用于记录所有依赖包及其精确的版本号。

    主要的用法如下

    1
    2
    pip freeze > requirements.txt  # 生成requirements.txt
    pip install -r requirements.txt # 从requirements.txt安装依赖

    文件中支持的写法

    1
    2
    3
    4
    5
    -r base.txt  # base.txt下面的所有包
    pypinyin==0.12.0 # 指定版本(最日常的写法)
    django-querycount>=0.5.0 # 大于某个版本
    django-debug-toolbar>=1.3.1,<=1.3.3 # 版本范围
    ipython # 默认(存在不替换,不存在安装最新版)

请我喝杯咖啡吧~

支付宝
微信