今天看完了EEGNet的论文准备搭建一下EEGNet的网络,然后想到之前看过网络配置文件的内容,然后想着以后开发自己的网络的能够规范和方便,所以就学习一下,并在这里记录一下,方便以后查阅。
config配置文件原理及使用
config配置文件
config代码
a.yaml
DATA:
BATCH_SIZE: 512
MODEL:
TRANS:
EMBED_DIM: 768
config.py
from yacs.config import CfgNode as CN
import yaml
# 设置默认参数
_C = CN()
_C.DATA = CN()
_C.DATA.DATASET = 'cifar10'
_C.DATA.BATCH_SIZE = 128
_C.MODEL = CN()
_C.MODEL.NUM_CLASSES = 10
_C.MODEL.TRANS = CN()
_C.MODEL.TRANS.EMBED_DIM = 96
_C.MODEL.TRANS.DEPTHS = [2, 2, 6, 2]
_C.MODEL.TRANS.QKV_BIAS = False
# 通过yaml更新参数
def _update_config_from_file(config, cfg_file):
config.defrost()
config.merge_from_file(cfg_file) # .yaml
# 通过argparser.ArgumentParser更新参数
def update_config(config, args):
if args.cfg:
_update(config, args.cfg)
if args.dataset:
config.DATA.DATASET = args.datasert
if args.batch_size:
config.DATA.BATCH_SIZE = args.batch_size
return config
def get_config(cfg_file=None):
config = _C.clone()
if cfg_file:
_update_config_from_file(config, cfg_file)
return config
def main():
cfg = get_config('./a.yaml')
print(cfg)
if __name__ == "__main__":
main()
输出:
argparse.py
import argparse
from config import get_config
from config import update_config
def get_argument():
parser = argparse.ArgumentParser('ViT')
parser.add_argument('-cfg', type=str, default=None)
parser.add_argument('-dataset', type=str, default=None)
parser.add_argument('-batch_size', type=str, default=None)
arguments = parser.parse_args()
return arguments
def main():
cfg = get_config()
print(cfg)
print('-----------------')
cfg = get_config('./a.yaml')
print(cfg)
print('-----------------')
args = get_argument()
cfg = update_config(cfg, args)
print(cfg)
if __name__ == "__main__":
main()
输出:
config配置文件的使用