Fuses a list of modules into a single module.
Fuses only the following sequences of modules:
conv, bn;
conv, bn, relu;
conv, relu;
conv, bn, add;
conv, bn, add, relu;
conv, add;
conv, add, relu;
linear, bn;
linear, bn, relu;
linear, relu;
linear, bn, add;
linear, bn, add, relu;
linear, add;
linear, add, relu.
For these sequences, the first element in the output module list performs the fused operation, and the remaining elements are set to nn.Identity().
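The matching rule and output contract above can be modeled in plain Python. This is a minimal sketch under stated assumptions: module types are represented by lowercase strings, the fused module by a descriptive string, and `SUPPORTED_PATTERNS` and `fuse_sequence` are hypothetical names, not PyTorch APIs.

```python
from typing import List

# Hypothetical table mirroring the supported sequences listed above.
# Strings stand in for module types; this is not a PyTorch API.
SUPPORTED_PATTERNS = {
    ("conv", "bn"), ("conv", "bn", "relu"), ("conv", "relu"),
    ("conv", "bn", "add"), ("conv", "bn", "add", "relu"),
    ("conv", "add"), ("conv", "add", "relu"),
    ("linear", "bn"), ("linear", "bn", "relu"), ("linear", "relu"),
    ("linear", "bn", "add"), ("linear", "bn", "add", "relu"),
    ("linear", "add"), ("linear", "add", "relu"),
}


def fuse_sequence(kinds: List[str]) -> List[str]:
    """Return a list the same length as `kinds`: the first entry stands
    for the fused module, every later entry for nn.Identity()."""
    if tuple(kinds) not in SUPPORTED_PATTERNS:
        # Raising here is a choice made for this sketch only.
        raise ValueError(f"unsupported fusion sequence: {kinds}")
    fused = "Fused(" + "+".join(kinds) + ")"
    return [fused] + ["Identity"] * (len(kinds) - 1)


print(fuse_sequence(["conv", "bn", "relu"]))
# ['Fused(conv+bn+relu)', 'Identity', 'Identity']
```

Keeping the output the same length as the input is what lets each fused result be written back under the original module names, with the extra slots becoming identity placeholders.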
Parameters:
model – Model containing the modules to be fused
modules_to_fuse – list of lists of module names to fuse. Can also be a single list of strings if there is only one group of modules to fuse.
inplace – bool specifying whether fusion happens in place on the model; by default, a new model is returned.
fuser_func – Function that takes in a list of modules and outputs a list of fused modules of the same length. For example, fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]. Defaults to torch.ao.quantization.fuse_known_modules.
fuse_custom_config_dict – custom configuration for fusion
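Because modules_to_fuse may be either a list of lists or, for a single group, a flat list of strings, the argument has to be normalized before the groups are processed. A minimal sketch, assuming a flat list is recognized by all of its elements being strings; `normalize_modules_to_fuse` is a hypothetical helper, not part of the PyTorch API.

```python
from typing import List, Union


def normalize_modules_to_fuse(
    modules_to_fuse: Union[List[str], List[List[str]]],
) -> List[List[str]]:
    """Hypothetical helper: always return a list of name groups."""
    if not modules_to_fuse:
        return []
    # A flat list of strings describes exactly one group to fuse.
    if all(isinstance(name, str) for name in modules_to_fuse):
        return [list(modules_to_fuse)]
    # Otherwise treat it as a list of groups, copying each group.
    return [list(group) for group in modules_to_fuse]


print(normalize_modules_to_fuse(["conv1", "bn1", "relu1"]))
# [['conv1', 'bn1', 'relu1']]
print(normalize_modules_to_fuse([["conv1", "bn1"], ["conv2", "relu2"]]))
# [['conv1', 'bn1'], ['conv2', 'relu2']]
```

The explicit empty check matters because `all(...)` is True for an empty list, which would otherwise produce a single empty group.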
Returns: model with fused modules. A new copy is created if inplace=False.
Examples:
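The return-value semantics can be illustrated without PyTorch. In this toy sketch, a "model" is a plain dict mapping module names to type strings, and `toy_fuse_modules` is a hypothetical stand-in for the real function: it deep-copies the model unless inplace=True, then replaces the first name of each group with a fused entry and the remaining names with "Identity".

```python
import copy
from typing import Dict, List


def toy_fuse_modules(
    model: Dict[str, str],
    modules_to_fuse: List[List[str]],
    inplace: bool = False,
) -> Dict[str, str]:
    # Work on a copy unless the caller asked for in-place fusion.
    if not inplace:
        model = copy.deepcopy(model)
    for group in modules_to_fuse:
        kinds = [model[name] for name in group]
        # The first module of the group carries the fused operation...
        model[group[0]] = "Fused(" + "+".join(kinds) + ")"
        # ...and every later module becomes an identity placeholder.
        for name in group[1:]:
            model[name] = "Identity"
    return model


original = {"conv1": "conv", "bn1": "bn", "relu1": "relu", "fc": "linear"}
fused = toy_fuse_modules(original, [["conv1", "bn1", "relu1"]])
print(fused)
# {'conv1': 'Fused(conv+bn+relu)', 'bn1': 'Identity', 'relu1': 'Identity', 'fc': 'linear'}
print(original["bn1"])  # bn (untouched, because inplace defaults to False)

same = toy_fuse_modules(original, [["conv1", "bn1"]], inplace=True)
print(same is original)  # True
```

This toy skips the pattern check and the fuser_func call described in the parameters; it only shows how names are rewritten and when the input model is copied.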