Different ways of counting PyTorch model FLOPs (library compilation)

pytorch
FLOPs
Author

Andre Mirończuk

Published

January 27, 2025

Profiling libraries in one spot.

Keep in mind that automatic FLOP measurements are approximations and can be imprecise, particularly when faced with non-standard layers. Custom kernels, for example, will be outright skipped if no formulas or specific values are manually registered for them. Similarly, unsupported operations will not contribute to the final estimation; sparse tensors may yield the same FLOP counts as their dense counterparts. And so on…
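For instance, elementwise work has no registered formula in PyTorch's built-in counter (Option 1 below), so it contributes nothing to the total; a minimal sketch (the tensor and ops here are just an illustration, not part of the original post):

import torch
from torch.utils.flop_counter import FlopCounterMode

x = torch.randn(1024, 1024)
fc = FlopCounterMode(display=False)
with fc:
  y = x.sigmoid() * x  # real floating-point work, but no registered FLOP formula for elementwise ops
print(fc.get_total_flops())  # expect 0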

That being said, those numbers are still accurate enough in a lot of scenarios and can be quite handy, especially when comparing similar architectures (well, strictly speaking, the algorithms running on those architectures).

You can check your intuition by asking yourself: What is the FLOP count for an embedding layer? Backward pass? What about fine-tuning a model with the embedding layer being frozen? A de-embedding layer?

from torchvision.models import wide_resnet50_2
import torch

wide_resnet = wide_resnet50_2(weights=None)
input_shape = (1, 3, 244, 244)
input_tensor = torch.randn(input_shape)

Option 1: PyTorch’s inbuilt FLOP counter

As of now this module is actually undocumented; more info in the resources section. The code below takes both the forward and the backward pass into account.

from torch.utils.flop_counter import FlopCounterMode

flop_counter = FlopCounterMode(display=False, depth=None)
with flop_counter:
  wide_resnet(input_tensor).sum().backward()
total_flops_one_fwd_bwd: int = flop_counter.get_total_flops()
print(f"Total GigaFLOPs one forward-backward: {total_flops_one_fwd_bwd / 1e9}")
Total GigaFLOPs one forward-backward: 86.000790528
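While we are here, the same counter also answers the embedding-layer question from the intro. A quick sketch (the toy layers and shapes below are mine, not from the post):

import torch.nn as nn

emb = nn.Embedding(num_embeddings=1000, embedding_dim=64)
tokens = torch.randint(0, 1000, (1, 16))
emb_counter = FlopCounterMode(display=False, depth=None)
with emb_counter:
  emb(tokens).sum().backward()
print("Embedding fwd+bwd FLOPs:", emb_counter.get_total_flops())  # lookups are gathers/scatters, so expect (close to) 0

deemb = nn.Linear(64, 1000, bias=False)  # a "de-embedding" projection
lin_counter = FlopCounterMode(display=False, depth=None)
with lin_counter:
  deemb(torch.randn(1, 16, 64)).sum().backward()
print("De-embedding fwd+bwd FLOPs:", lin_counter.get_total_flops())  # matmuls do get counted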

Option 2: torchinfo

This time only the forward pass.

from torchinfo import summary
model_stats = summary(wide_resnet, input_size=input_shape, verbose=0)
model_stats
# print model_stats reasonably
from IPython.display import Markdown, display
summary_str = str(model_stats)

display(Markdown("```\n" + summary_str[:500] + "\n```"))
print("...")
display(Markdown("```\n" + summary_str[-645:] + "\n```"))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 122, 122]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 122, 122]         128
├─ReLU: 1-3                  
...
├─AdaptiveAvgPool2d: 1-9                 [1, 2048, 1, 1]           --
├─Linear: 1-10                           [1, 1000]                 2,049,000
==========================================================================================
Total params: 68,883,240
Trainable params: 68,883,240
Non-trainable params: 0
Total mult-adds (G): 14.38
==========================================================================================
Input size (MB): 0.71
Forward/backward pass size (MB): 282.22
Params size (MB): 275.53
Estimated Total Size (MB): 558.47
==========================================================================================

Total mult-adds (G): 14.38 is of interest to us (1 GMAC = 10⁹ multiply-accumulate operations, and 1 MAC ≈ 2 FLOPs). So 28.76 GFLOPs.
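The same figure can be pulled straight off the returned ModelStatistics object instead of reading the printed table; a sketch assuming your torchinfo version exposes total_mult_adds:

gmacs = model_stats.total_mult_adds / 1e9  # assumption: ModelStatistics.total_mult_adds holds the raw MAC count
print(f"GMACs: {gmacs:.2f}, GFLOPs (x2): {2 * gmacs:.2f}")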

Given that the backward pass is often assumed to cost about twice the FLOPs of the forward one, this checks out: 28.76 (forward) + 2 × 28.76 (backward) = 86.28 GFLOPs, which is close to PyTorch’s FlopCounterMode output.
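You can also verify that split directly with FlopCounterMode by counting a forward-only pass; a small sketch reusing the variables from Option 1:

fwd_counter = FlopCounterMode(display=False, depth=None)
with fwd_counter:
  wide_resnet(input_tensor)  # forward only, no backward
fwd_flops = fwd_counter.get_total_flops()
print(f"Forward GFLOPs: {fwd_flops / 1e9}")
print(f"Backward/forward FLOP ratio: {(total_flops_one_fwd_bwd - fwd_flops) / fwd_flops:.2f}")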

Option 3: deepspeed

from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator

# with get_accelerator().device(0):
flops, macs, params = get_model_profile(
  model=wide_resnet,
  input_shape=input_shape,
  args=None,
  kwargs=None,
  print_profile=True,
  detailed=True,
  module_depth=-1,
  top_modules=1,
  warm_up=10,
  as_string=True,
  output_file=None,
  ignore_modules=None,
)

print("Params:", params)
print("GMACs:", macs)
print("GFLOPs:", flops)
Params: 68.88 M
GMACs: 14.38 GMACs
GFLOPs: 28.81 G

Besides these values, deepspeed outputs quite a detailed profiling report.
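If you want to keep that report, the same call can write it to disk through the output_file argument that was left as None above; a sketch (the file name is arbitrary):

flops, macs, params = get_model_profile(
  model=wide_resnet,
  input_shape=input_shape,
  print_profile=True,
  detailed=True,
  output_file="wide_resnet_flops_profile.txt",
)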

Option 4: fvcore

from fvcore.nn import FlopCountAnalysis

flops = FlopCountAnalysis(wide_resnet, input_tensor)
print("FLOPs: ", flops.total())
print(f"GFLOPs: {flops.total() / 1e9}")
Unsupported operator aten::add_ encountered 69 time(s)
Unsupported operator aten::max_pool2d encountered 1 time(s)
FLOPs:  14468464384
GFLOPs: 14.468464384

Well, interestingly, we got almost exactly as many “FLOPs” as the MAC counts from torchinfo and deepspeed: fvcore counts one fused multiply-add as a single FLOP. The “Unsupported operator … encountered” warnings are also worth noting, since those ops are skipped entirely.
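fvcore can also break the count down per module or per operator; a sketch assuming flop_count_table and by_operator are available in your fvcore version:

from fvcore.nn import flop_count_table

print(flop_count_table(flops, max_depth=1))  # per-module table (string)
print(flops.by_operator())                   # Counter: operator name -> count (MAC-style)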

Option 5: ptflops

from ptflops import get_model_complexity_info

macs, params = get_model_complexity_info(
  wide_resnet, input_shape[1:], as_strings=False, print_per_layer_stat=False, backend='pytorch'
)
print("Params:", params)
print("GMACs:", macs / 1e9)
print("GFLOPs:", (macs / 1e9) * 2)
Params: 68883240
GMACs: 14.44918756
GFLOPs: 28.89837512

Option 6: flopth

from flopth import flopth

flops, params = flopth(wide_resnet, in_size=input_shape[1:])
print("Params:", params)
print("GFLOPs:", flops)
Params: 68.8832M
GFLOPs: 14.4242G

Option 7: calflops

from calflops import calculate_flops

flops, macs, params = calculate_flops(
  model=wide_resnet, input_shape=input_shape, output_as_string=True, output_precision=4
)

print("Params:", params)
print("GMACs:", macs[:7])
print("GFLOPs:", flops[:7])
Params: 68.8832 M
GMACs: 14.3801
GFLOPs: 28.8124

Option 8: thop

from thop import profile

macs, params = profile(wide_resnet, inputs=(input_tensor, ))
print("Params:", params)
print("GMACs:", macs / 1e9)
Params: 68883240.0
GMACs: 43.352484096

43 GMACs is quite a number compared to the other profilers: roughly three times the MAC count everything else reports.

If you’re looking for a dedicated third-party library, I would choose between deepspeed, calflops (Transformers), and fvcore (CV).

Other Resources