anomalib代码解析之三:训练过程

news/2024/11/15 6:01:54/

咱们吃个回头草吧

上面的图中,第55行,藏有玄机。前面我们没详细讲。就是这行,指定了,cfa算法,怎么训练的,算法实现细节都在这里。

那我们就得看get_model函数了:

def get_model(config: DictConfig | ListConfig) -> AnomalyModule:"""Load model from the configuration file.Works only when the convention for model naming is followed.The convention for writing model classes is`anomalib.models.<model_name>.lightning_model.<ModelName>Lightning``anomalib.models.stfpm.lightning_model.StfpmLightning`Args:config (DictConfig | ListConfig): Config.yaml loaded using OmegaConfRaises:ValueError: If unsupported model is passedReturns:AnomalyModule: Anomaly Model"""logger.info("Loading the model.")model_list: list[str] = ["cfa","cflow","csflow","dfkde","dfm","draem","fastflow","ganomaly","padim","patchcore","reverse_distillation","rkde","stfpm",]model: AnomalyModuleif config.model.name in model_list:module = import_module(f"anomalib.models.{config.model.name}")model = getattr(module, f"{_snake_to_pascal_case(config.model.name)}Lightning")(config)print("---------------getattr")print(getattr(module, f"{_snake_to_pascal_case(config.model.name)}Lightning"))print("---------------getattr-end")else:raise ValueError(f"Unknown model {config.model.name}!")if "init_weights" in config.keys() and config.init_weights:model.load_state_dict(load(os.path.join(config.project.path, config.init_weights))["state_dict"], strict=False)return model

 这里面,最关键的是,

一,关于module

module = import_module(f"anomalib.models.{config.model.name}")

因为在config.yaml中,已经指定,

anomalib.models.{config.model.name}就是anomalib.models.cfa


 

那么,本质就是执行的下面的代码

module = import_module(anomalib.models.cfa)

 记住上面这行代码,重要!

于是,module就是下面这个东西:

<module 'anomalib.models.cfa' from 'D:\\BaiduNetdiskDownload\\anomalib\\anomalib
-main\\src\\anomalib\\models\\cfa\\__init__.py'>

二,关于model

model = getattr(module, f"{_snake_to_pascal_case(config.model.name)}Lightning")(config)

其实就是执行,

model=getattr(module,CfaLightning)(config)

换句话说,就是

model=CfaLightning(config)

三、关于训练数据datamodule

 

 关键是看看那个54行的

datamodule = get_datamodule(config)
def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:"""Get Anomaly Datamodule.Args:config (DictConfig | ListConfig): Configuration of the anomaly model.Returns:PyTorch Lightning DataModule"""logger.info("Loading the datamodule")datamodule: AnomalibDataModule# convert center crop to tuplecenter_crop = config.dataset.get("center_crop")if center_crop is not None:center_crop = (center_crop[0], center_crop[1])if config.dataset.format.lower() == "mvtec":datamodule = MVTec(root=config.dataset.path,category=config.dataset.category,image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),center_crop=center_crop,normalization=config.dataset.normalization,train_batch_size=config.dataset.train_batch_size,eval_batch_size=config.dataset.eval_batch_size,num_workers=config.dataset.num_workers,task=config.dataset.task,transform_config_train=config.dataset.transform_config.train,transform_config_eval=config.dataset.transform_config.eval,test_split_mode=config.dataset.test_split_mode,test_split_ratio=config.dataset.test_split_ratio,val_split_mode=config.dataset.val_split_mode,val_split_ratio=config.dataset.val_split_ratio,)elif config.dataset.format.lower() == "mvtec_3d":datamodule = MVTec3D(root=config.dataset.path,category=config.dataset.category,image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),。。。。。。。。。。return datamodule

这个datamodule,是什么类型的呢?

AnomalibDataModule类型的!而AnomalibDataModule的父类,是LightningDataModule类型
class AnomalibDataModule(LightningDataModule, ABC):"""Base Anomalib data module.Args:train_batch_size (int): Batch size used by the train dataloader.test_batch_size (int): Batch size used by the val and test dataloaders.num_workers (int): Number of workers used by the train, val and test dataloaders.test_split_mode (Optional[TestSplitMode], optional): Determines how the test split is obtained.Options: [none, from_dir, synthetic]test_split_ratio (float): Fraction of the train images held out for testing.val_split_mode (ValSplitMode): Determines how the validation split is obtained. Options: [none, same_as_test,from_test, synthetic]val_split_ratio (float): Fraction of the train or test images held our for validation.seed (int | None, optional): Seed used during random subset splitting."""def __init__(self,train_batch_size: int,eval_batch_size: int,num_workers: int,val_split_mode: ValSplitMode,val_split_ratio: float,test_split_mode: TestSplitMode | None = None,test_split_ratio: float | None = None,seed: int | None = None,) -> None:super().__init__()self.train_batch_size = train_batch_sizeself.eval_batch_size = eval_batch_sizeself.num_workers = num_workersself.test_split_mode = test_split_modeself.test_split_ratio = test_split_ratioself.val_split_mode = val_split_modeself.val_split_ratio = val_split_ratioself.seed = seedself.train_data: AnomalibDatasetself.val_data: AnomalibDatasetself.test_data: AnomalibDatasetself._samples: DataFrame | None = Nonedef setup(self, stage: str | None = None) -> None:"""Setup train, validation and test data.Args:stage: str | None:  Train/Val/Test stages. (Default value = None)"""if not self.is_setup:self._setup(stage)assert self.is_setupdef _setup(self, _stage: str | None = None) -> None:"""Set up the datasets and perform dynamic subset splitting.This method may be overridden in subclass for custom splitting behaviour.Note: The stage argument is not used here. This is because, for a given instance of an AnomalibDataModulesubclass, all three subsets are created at the first call of setup(). This is to accommodate the subsetsplitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, andthe test set must therefore be created as early as the `fit` stage."""assert self.train_data is not Noneassert self.test_data is not Noneself.train_data.setup()self.test_data.setup()self._create_test_split()self._create_val_split()def _create_test_split(self) -> None:"""Obtain the test set based on the settings in the config."""if self.test_data.has_normal:# split the test data into normal and anomalous so these can be processed separatelynormal_test_data, self.test_data = split_by_label(self.test_data)elif self.test_split_mode != TestSplitMode.NONE:# when the user did not provide any normal images for testing, we sample some from the training set,# except when the user explicitly requested no test splitting.logger.info("No normal test images found. Sampling from training set using a split ratio of %d",self.test_split_ratio,)if self.test_split_ratio is not None:self.train_data, normal_test_data = random_split(self.train_data, self.test_split_ratio, seed=self.seed)if self.test_split_mode == TestSplitMode.FROM_DIR:self.test_data += normal_test_dataelif self.test_split_mode == TestSplitMode.SYNTHETIC:self.test_data = SyntheticAnomalyDataset.from_dataset(normal_test_data)elif self.test_split_mode != TestSplitMode.NONE:raise ValueError(f"Unsupported Test Split Mode: {self.test_split_mode}")def _create_val_split(self) -> None:"""Obtain the validation set based on the settings in the config."""if self.val_split_mode == ValSplitMode.FROM_TEST:# randomly sampled from test setself.test_data, self.val_data = random_split(self.test_data, self.val_split_ratio, label_aware=True, seed=self.seed)elif self.val_split_mode == ValSplitMode.SAME_AS_TEST:# equal to test setself.val_data = self.test_dataelif self.val_split_mode == ValSplitMode.SYNTHETIC:# converted from random training sampleself.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)elif self.val_split_mode != ValSplitMode.NONE:raise ValueError(f"Unknown validation split mode: {self.val_split_mode}")@propertydef is_setup(self) -> bool:"""Checks if setup() has been called.At least one of [train_data, val_data, test_data] should be setup."""_is_setup: bool = Falsefor data in ("train_data", "val_data", "test_data"):if hasattr(self, data):if getattr(self, data).is_setup:_is_setup = Truereturn _is_setupdef train_dataloader(self) -> TRAIN_DATALOADERS:"""Get train dataloader."""return DataLoader(dataset=self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers)def val_dataloader(self) -> EVAL_DATALOADERS:"""Get validation dataloader."""return DataLoader(dataset=self.val_data,shuffle=False,batch_size=self.eval_batch_size,num_workers=self.num_workers,collate_fn=collate_fn,)def test_dataloader(self) -> EVAL_DATALOADERS:"""Get test dataloader."""return DataLoader(dataset=self.test_data,shuffle=False,batch_size=self.eval_batch_size,num_workers=self.num_workers,collate_fn=collate_fn,)


http://www.ppmy.cn/news/38219.html

相关文章

Python矩阵分解之QR分解

文章目录QR和RQ分解其他函数QR和RQ分解 记AAA为方阵&#xff0c;P,QP, QP,Q分别为正交单位阵和上三角阵&#xff0c;则形如AQRAQRAQR的分解为QR分解&#xff1b;形如ARQARQARQ的分解为RQ分解。 在scipy.linalg中&#xff0c;为二者提供了相同的参数&#xff0c;除了待分解矩阵…

SPSS27破解安装后,出现应用程序无法正常启动(0xc000007b)

破解完SPSS 27软件后&#xff0c;点击图标出现下图错误 可以尝试以下方法&#xff1a; 1. 在安装目录下找到VC开头的文件夹 2. 点击此软件进行修复 若修复完成&#xff0c;重新启动SPSS软件即可。 3. 若提示错误&#xff0c;显示如下界面&#xff0c;进行下面的方法 4. 下…

用户行为分析zhi应用分析模型

&#xff08;1&#xff09;基于AARRR漏斗模型分析用户行为 本文通过常用的电商数据分析指标&#xff0c;采用AARRR漏斗模型拆解用户进入APP后的每一步行为。AARRR模型是根据用户使用产品全流程的不同阶段进行划分的&#xff0c;针对每一环节的用户流失情况分析出不同环节的优化…

斐波拉契数列,有人买了一对小兔子,已知小兔子一个月后长成大兔子,大兔子每个月生一对小兔子,问:两年(24个月)之后,他一共有几对兔子。

[01]斐波拉契数列&#xff0c;有人买了一对小兔子&#xff0c;已知小兔子一个月后长成大兔子&#xff0c;大兔子每个月生一对小兔子&#xff0c;问:两年(24个月)之后&#xff0c;他一共有几对兔子。 第i月份大兔子小兔子总兔子1011210131124213532565387851381382192113341034…

7_linux进程管理

7_linux进程管理 文章目录7_linux进程管理一、进程定义二、进程查看命令三、进程优先级四、进程前后台调用五、进程信息号六、systemd守护进程七、系统中的登陆审计一、进程定义 进程的定义 程序是静态的代码文件进程是指程序运行时的形态进程是程序的一个副本进程是有生命周期…

不同路径数(DFS+哈希)

给定一个 nm 的二维矩阵&#xff0c;其中的每个元素都是一个 [1,9] 之间的正整数。 从矩阵中的任意位置出发&#xff0c;每次可以沿上下左右四个方向前进一步&#xff0c;走过的位置可以重复走。 走了 k 次后&#xff0c;经过的元素会构成一个 (k1) 位数。 请求出一共可以走…

易语言支持库配置闪退丨支持库配置崩溃_易语言打开支持库配置就闪退怎么办?

易语言支持库配置闪退打不开怎么办&#xff1f; 易语言支持库配置闪退解决方法丨支持库配置崩溃_易语言打开支持易语言支持库配置闪退丨支持库配置崩溃_易语言打开支持库配置就闪退怎么办? 很多人都遇到过 打开易语言支持库安装菜单报错退出的问题 今天教大家解决方法 我们…

腾讯云服务器端口怎么全打开?

腾讯云服务器端口怎么全开&#xff1f;云服务器CVM在安全组中设置开通&#xff0c;轻量应用服务器在防火墙中设置&#xff0c;腾讯云百科来详细说下腾讯云服务器端口全开放教程&#xff1a; 目录 腾讯云服务器端口全部开通教程 云服务器CVM端口全开放教程 轻量应用服务器开…