Flops Profiler
Measures the parameters, latency, and floating-point operations of a PyTorch model.
- class flops_profiler.profiler.FlopsProfiler(model, ds_engine=None)
Bases: object
Measures the latency, number of estimated floating-point operations, and parameters of each module in a PyTorch model.
The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input.
Here is an example for usage in a typical training workflow:
    from flops_profiler.profiler import FlopsProfiler

    model = Model()
    prof = FlopsProfiler(model)

    for step, batch in enumerate(data_loader):
        # Start collecting just before the forward pass of the chosen step.
        if step == profile_step:
            prof.start_profile()

        loss = model(batch)

        # Read and print the results, then remove the profiling hooks.
        if step == profile_step:
            flops = prof._get_total_flops(as_string=True)
            params = prof._get_total_params(as_string=True)
            prof.print_model_profile(profile_step=profile_step)
            prof.end_profile()

        loss.backward()
        optimizer.step()
To profile a trained model in inference, use the get_model_profile API.
- Parameters:
model (torch.nn.Module) – The PyTorch model to profile.
- end_profile()
Ends profiling.
The added attributes and handles are removed recursively on all the modules.
- print_model_aggregated_profile(module_depth=-1, top_modules=1)
Prints the names of the top top_modules modules in terms of aggregated time, flops, and parameters at depth module_depth.
- Parameters:
module_depth (int, optional) – the depth of the modules to show. Defaults to -1 (the innermost modules).
top_modules (int, optional) – the number of top modules to show. Defaults to 1.
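The selection described above can be illustrated with a small, library-independent sketch. Everything in it is hypothetical: `ModuleStats`, `group_by_depth`, and `top_modules_at_depth` are not part of flops_profiler, and grouping modules by their type name at a given depth is an assumption about what "aggregated" means here. It only shows how `module_depth=-1` can resolve to the innermost level and how `top_modules` truncates the ranking.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ModuleStats:
    """Hypothetical per-module record; not a flops_profiler type."""
    name: str
    kind: str   # module type name, e.g. "Conv2d"
    flops: int  # flops of this module including its children
    children: List["ModuleStats"] = field(default_factory=list)


def group_by_depth(root: ModuleStats) -> Dict[int, List[ModuleStats]]:
    """Walk the tree and bucket modules by depth (the root is depth 0)."""
    levels: Dict[int, List[ModuleStats]] = {}
    stack = [(root, 0)]
    while stack:
        node, depth = stack.pop()
        levels.setdefault(depth, []).append(node)
        stack.extend((child, depth + 1) for child in node.children)
    return levels


def top_modules_at_depth(root: ModuleStats, module_depth: int = -1,
                         top_modules: int = 1) -> List[Tuple[str, int]]:
    """Sum flops per module type at one depth and return the largest entries."""
    levels = group_by_depth(root)
    max_depth = max(levels)
    # -1 selects the innermost level; deeper requests are clamped to it.
    depth = max_depth if module_depth == -1 else min(module_depth, max_depth)
    totals: Dict[str, int] = {}
    for module in levels[depth]:
        totals[module.kind] = totals.get(module.kind, 0) + module.flops
    ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_modules]


# A made-up three-level model used only for illustration.
tree = ModuleStats("model", "Net", 700, [
    ModuleStats("features", "Sequential", 600, [
        ModuleStats("conv1", "Conv2d", 250),
        ModuleStats("conv2", "Conv2d", 350),
    ]),
    ModuleStats("classifier", "Sequential", 100, [
        ModuleStats("fc", "Linear", 100),
    ]),
])

print(top_modules_at_depth(tree))                      # innermost level, top 1
print(top_modules_at_depth(tree, module_depth=1))      # one level below the root
print(top_modules_at_depth(tree, top_modules=2))       # innermost level, top 2
```

The same ranking could be repeated for latency or parameter counts by swapping the field that is summed.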
- print_model_profile(profile_step=1, module_depth=-1, top_modules=1, detailed=True, output_file=None)
Prints the model graph with the measured profile attached to each module.
- Parameters:
profile_step (int, optional) – The global training step at which to profile. Note that warm-up steps are needed for accurate time measurement.
module_depth (int, optional) – The depth of the model to which to print the aggregated module information. When set to -1, it prints information from the top to the innermost modules (the maximum depth).
top_modules (int, optional) – Limits the aggregated profile output to the number of top modules specified.
detailed (bool, optional) – Whether to print the detailed model profile.
output_file (str, optional) – Path to the output file. If None, the profiler prints to stdout.
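The note on warm-up steps can be made concrete with a generic timing sketch. It is not the profiler's own timing code: `timed_call` is a hypothetical helper, and the workload is an arbitrary stand-in for a forward pass. The idea is simply that the first few calls, which may pay one-time costs such as lazy initialization or cache population, are run untimed before the measured call.

```python
import time
from typing import Callable


def timed_call(fn: Callable[[], object], warm_up: int = 1) -> float:
    """Run fn `warm_up` times untimed, then time one more call (seconds)."""
    for _ in range(warm_up):
        fn()
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start


# Stand-in workload; `calls` records how often it actually ran.
calls = []
elapsed = timed_call(lambda: calls.append(sum(range(10_000))), warm_up=3)
print(f"timed 1 call after {len(calls) - 1} warm-up calls: {elapsed:.6f} s")
```

In a training loop the same effect is usually obtained by choosing a `profile_step` that is a few steps after the start rather than the very first step.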
- flops_profiler.profiler.get_model_profile(model, input_shape=None, args=[], kwargs={}, print_profile=True, detailed=True, module_depth=-1, top_modules=10, warm_up=3, as_string=True, output_file=None, ignore_modules=None, func_name='forward')
Returns the total floating-point operations, MACs, and parameters of a model.
Example:
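    import torchvision
    from flops_profiler.profiler import get_model_profile

    model = torchvision.models.alexnet()
    batch_size = 256
    # input_shape describes the single tensor passed to the model.
    flops, macs, params = get_model_profile(model=model,
                                            input_shape=(batch_size, 3, 224, 224))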
- Parameters:
model (torch.nn.Module) – the PyTorch model to be profiled.
input_shape (tuple) – input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
args (list) – list of positional arguments to the model.
kwargs (dict) – dictionary of keyword arguments to the model.
print_profile (bool, optional) – whether to print the model profile. Defaults to True.
detailed (bool, optional) – whether to print the detailed model profile. Defaults to True.
module_depth (int, optional) – the depth into the nested modules. Defaults to -1 (the innermost modules).
top_modules (int, optional) – the number of top modules to print in the aggregated profile. Defaults to 10.
warm_up (int, optional) – the number of warm-up steps before measuring the latency of each module. Defaults to 3.
as_string (bool, optional) – whether to return the results as formatted strings. Defaults to True.
output_file (str, optional) – path to the output file. If None, the profiler prints to stdout.
ignore_modules (list, optional) – the list of modules to ignore during profiling. Defaults to None.
- Returns:
The number of floating-point operations, multiply-accumulate operations (MACs), and parameters in the model.
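To show how the three returned quantities relate, here is a hand calculation for a single fully connected layer. It is a sketch under stated assumptions, not the profiler's counting code: `linear_layer_profile` is a hypothetical helper, one MAC is counted as two floating-point operations (a common convention), and the bias additions are left out of the operation counts.

```python
from typing import Tuple


def linear_layer_profile(batch_size: int, in_features: int,
                         out_features: int, bias: bool = True) -> Tuple[int, int, int]:
    """Return (flops, macs, params) for y = x @ W.T + b, by hand."""
    # Each output element needs in_features multiply-accumulates.
    macs = batch_size * in_features * out_features
    # Assumption: 1 MAC = 1 multiply + 1 add = 2 floating-point operations.
    flops = 2 * macs
    # Parameters do not depend on the batch size.
    params = in_features * out_features + (out_features if bias else 0)
    return flops, macs, params


flops, macs, params = linear_layer_profile(batch_size=256,
                                           in_features=4096,
                                           out_features=1000)
print(flops, macs, params)  # 2097152000 1048576000 4097000
```

The operation counts scale with the batch size while the parameter count does not, which is why the input shape matters when comparing profiles of the same model.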