from torchvision.models import wide_resnet50_2
import torch
wide_resnet = wide_resnet50_2(weights=None)
input_shape = (1, 3, 244, 244)
input_tensor = torch.randn(input_shape)
Profiling libraries in one spot.
Keep in mind that automatic FLOP measurements are approximations and can be imprecise, particularly when faced with non-standard layers. Custom kernels, for example, will be outright skipped if no formulas or specific values are manually registered for them. Similarly, unsupported operations will not contribute to the final estimation; sparse tensors may yield the same FLOP counts as their dense counterparts. And so on…
That being said, those numbers are still accurate enough in many scenarios and can be quite handy, especially when comparing similar architectures (or, more precisely, the algorithms running on those architectures).
You can check your intuition by asking yourself: What is the FLOP count for an embedding layer? Backward pass? What about fine-tuning a model with the embedding layer being frozen? A de-embedding layer?
Option 1: PyTorch’s inbuilt FLOP counter
As of now, this module is undocumented. The code below takes both the forward and the backward pass into account.
from torch.utils.flop_counter import FlopCounterMode
flop_counter = FlopCounterMode(display=False, depth=None)
with flop_counter:
    # one forward + backward pass, recorded by the counter
    wide_resnet(input_tensor).sum().backward()
total_flops_one_fwd_bwd: int = flop_counter.get_total_flops()
print(f"Total GigaFLOPs one forward-backward: {total_flops_one_fwd_bwd / 1e9}")
Total GigaFLOPs one forward-backward: 86.000790528
Option 2: torchinfo
This time only the forward pass.
from torchinfo import summary
model_stats = summary(wide_resnet, input_size=input_shape, verbose=0)
model_stats
Print model_stats reasonably:
from IPython.display import Markdown, display
summary_str = str(model_stats)
display(Markdown("```\n" + summary_str[:500] + "\n```"))
print("...")
display(Markdown("```\n" + summary_str[-645:] + "\n```"))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ResNet [1, 1000] --
├─Conv2d: 1-1 [1, 64, 122, 122] 9,408
├─BatchNorm2d: 1-2 [1, 64, 122, 122] 128
├─ReLU: 1-3
...
├─AdaptiveAvgPool2d: 1-9 [1, 2048, 1, 1] --
├─Linear: 1-10 [1, 1000] 2,049,000
==========================================================================================
Total params: 68,883,240
Trainable params: 68,883,240
Non-trainable params: 0
Total mult-adds (G): 14.38
==========================================================================================
Input size (MB): 0.71
Forward/backward pass size (MB): 282.22
Params size (MB): 275.53
Estimated Total Size (MB): 558.47
==========================================================================================
Total mult-adds (G): 14.38 is the line of interest to us (G meaning Giga-MACs, i.e. 10⁹ multiply-accumulates). Counting each MAC as 2 FLOPs, that comes to 28.76 GFLOPs.
Given that the backward pass is commonly assumed to cost about twice as many FLOPs as the forward one, this checks out: 28.76 * 3 = 86.28 GFLOPs, which is close to PyTorch's FlopCounterMode output.
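If you want these numbers programmatically rather than by parsing the printed table, torchinfo's ModelStatistics object exposes the totals as attributes. The total_mult_adds attribute used below is an assumption about that object; double it for an approximate forward FLOP count.
# assumes ModelStatistics exposes the forward-pass MAC count as total_mult_adds
approx_fwd_gflops = 2 * model_stats.total_mult_adds / 1e9
print(f"Approx. forward GFLOPs: {approx_fwd_gflops:.2f}")
print(f"Approx. forward+backward GFLOPs (x3 rule of thumb): {3 * approx_fwd_gflops:.2f}")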
Option 3: deepspeed
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator
# with get_accelerator().device(0):
flops, macs, params = get_model_profile(
    model=wide_resnet,
    input_shape=input_shape,
    args=None,
    kwargs=None,
    print_profile=True,
    detailed=True,
    module_depth=-1,
    top_modules=1,
    warm_up=10,
    as_string=True,
    output_file=None,
    ignore_modules=None,
)
print(flops, macs, params)
print("Params:", params)
print("GMACs:", macs)
print("GFLOPs:", flops)Params: 68.88 M
GMACs: 14.38 GMACs
GFLOPs: 28.81 G
Besides these values, deepspeed outputs quite a detailed profiling report.
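If you need measurements from inside an existing training loop rather than a one-off call, DeepSpeed also ships a FlopsProfiler class. The sketch below assumes its start_profile/stop_profile/get_total_flops/end_profile methods behave as in DeepSpeed's examples; double-check against your installed version.
from deepspeed.profiling.flops_profiler import FlopsProfiler

prof = FlopsProfiler(wide_resnet)
prof.start_profile()
wide_resnet(input_tensor)  # the forward step you want to measure
prof.stop_profile()
print("GFLOPs:", prof.get_total_flops() / 1e9)
print("Params:", prof.get_total_params())
prof.end_profile()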
Option 4: fvcore
from fvcore.nn import FlopCountAnalysis
flops = FlopCountAnalysis(wide_resnet, input_tensor)
print("FLOPs: ", flops.total())
print(f"GFLOPs: {flops.total() / 1e9}")Unsupported operator aten::add_ encountered 69 time(s)
Unsupported operator aten::max_pool2d encountered 1 time(s)
FLOPs: 14468464384
GFLOPs: 14.468464384
Interestingly, fvcore reports almost exactly as many FLOPs as torchinfo and deepspeed report MACs, so it appears to count a fused multiply-add as a single FLOP.
Encounter counter does not disappoint.
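The unsupported-operator warnings tie back to the caveat from the beginning: unless a formula is registered, those ops contribute zero. As a rough sketch (assuming fvcore's set_op_handle hook and its get_shape helper), an elementwise handler for aten::add_ could look like this:
from math import prod
from fvcore.nn.jit_handles import get_shape  # helper fvcore uses in its built-in handles

def elementwise_flop_jit(inputs, outputs):
    # rough assumption: one FLOP per output element
    return prod(get_shape(outputs[0]))

flops = FlopCountAnalysis(wide_resnet, input_tensor)
flops.set_op_handle("aten::add_", elementwise_flop_jit)
print(f"GFLOPs with aten::add_ counted: {flops.total() / 1e9}")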
Option 5: ptflops
from ptflops import get_model_complexity_info
macs, params = get_model_complexity_info(
    wide_resnet, input_shape[1:], as_strings=False, print_per_layer_stat=False, backend='pytorch'
)
print("Params:", params)
print("GMACs:", macs / 1e9)
print("GFLOPs:", (macs / 1e9) * 2)
Params: 68883240
GMACS: 14.44918756
GFLOPS: 28.89837512
Option 6: flopth
from flopth import flopth
flops, params = flopth(wide_resnet, in_size=input_shape[1:])
print("Params:", params)
print("GFLOPs:", flops)Params: 68.8832M
GFLOPs: 14.4242G
Option 7: calflops
from calflops import calculate_flops
flops, macs, params = calculate_flops(
    model=wide_resnet, input_shape=input_shape, output_as_string=True, output_precision=4
)
print("Params:", params)
print("GMACs:", macs[:7])
print("GFLOPs:", flops[:7])Params: 68.8832 M
GMACs: 14.3801
GFLOPs: 28.8124
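calflops can also fold a backward-pass estimate into the result. The flag names below (include_backPropagation, compute_bp_factor) are assumptions based on its documented options, so verify them against your installed version.
flops_bp, macs_bp, params_bp = calculate_flops(
    model=wide_resnet,
    input_shape=input_shape,
    include_backPropagation=True,  # assumed flag: add backward-pass FLOPs to the total
    compute_bp_factor=2.0,         # assumed flag: backward costs ~2x the forward
    output_as_string=True,
    output_precision=4,
)
print("GFLOPs forward+backward:", flops_bp)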
Option 8: thop
from thop import profile
macs, params = profile(wide_resnet, inputs=(input_tensor, ))
print("Params:", params)
print("GMACs:", macs / 1e9)Params: 68883240.0
GMACs: 43.352484096
43 GMACs is quite a number compared to the ~14.4 GMACs reported by the other profilers.
If you’re looking for a dedicated third-party library, I would choose between deepspeed, calflops (for Transformers), and fvcore (for CV).