Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional, Sequence | |
| from mmengine.hooks import Hook | |
| from mmengine.model import is_model_wrapper | |
| from mmseg.registry import HOOKS | |
| import torch | |
| import json | |
| import os | |
| def group_subnets_by_flops(data, flops_step=10): | |
| sorted_data = {k: v for k, v in sorted(data.items(), key=lambda item: item[1])} | |
| candidate_idx = [] | |
| grouped_cands = [] | |
| last_flops = 0 | |
| for cfg_id, flops in sorted_data.items(): | |
| # flops, _ = values | |
| flops = flops / 1e9 | |
| if abs(last_flops - flops) > flops_step: | |
| if len(candidate_idx) > 0: | |
| grouped_cands.append(candidate_idx) | |
| candidate_idx = [int(cfg_id)] | |
| last_flops = flops | |
| else: | |
| candidate_idx.append(int(cfg_id)) | |
| if len(candidate_idx) > 0: | |
| grouped_cands.append(candidate_idx) | |
| return grouped_cands | |
| def initialize_model_stitching_layer(model, dataiter): | |
| images = [] | |
| total_samples = 50 | |
| while len(images) < total_samples: | |
| item = next(dataiter) | |
| data = model.data_preprocessor(item, True) | |
| images.extend(data['inputs']) | |
| images = torch.stack(images, dim=0) | |
| samples = images.cuda() | |
| model.backbone.initialize_stitching_weights(samples) | |
| class SNNetHook(Hook): | |
| """Docstring for NewHook. | |
| """ | |
| def before_train(self, runner) -> None: | |
| if is_model_wrapper(runner.model): | |
| model = runner.model.module | |
| else: | |
| model = runner.model | |
| if not runner._resume: | |
| initialize_model_stitching_layer(model, runner.train_loop.dataloader_iterator) | |
| # cfg = Config.fromfile(runner._cfg_file) | |
| cfg_name = runner.cfg.filename.split('/')[-1].split('.')[0] | |
| with open(os.path.join('./model_flops', f'snnet_flops_{cfg_name}.json'), 'r') as f: | |
| flops_params = json.load(f) | |
| flops_step = 10 | |
| grouped_subnet = group_subnets_by_flops(flops_params, flops_step) | |
| model.backbone.flops_grouped_cfgs = grouped_subnet | |