Advanced Machine Learning with Python (Session 2 - Part 2)

Fernando Cervantes (fernando.cervantes@jax.org)

Materials

Open notebook in Colab

Working with Transformers

Review the architecture of a Vision Transformer (ViT)

Dosovitskiy, Alexey et al. “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.” ArXiv abs/2010.11929 (2020)
  • https://docs.pytorch.org/vision/stable/models/vision_transformer.html

Vaswani, Ashish et al. “Attention is All you Need.” Neural Information Processing Systems (2017).
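In a ViT, the image is cut into 16x16 patches, each patch is linearly projected to a token embedding, and a learnable class token is prepended before the sequence enters a stack of Transformer encoder blocks. Below is a minimal sketch of the patch-embedding step; it mirrors, but is not, the torchvision implementation (the conv_proj layer in the model summary further down plays exactly this role).

import torch

# A 16x16 convolution with stride 16 turns each patch into one 768-d token
patchify = torch.nn.Conv2d(3, 768, kernel_size=16, stride=16)

image = torch.randn(1, 3, 384, 384)          # one RGB image at the input size used below
tokens = patchify(image)                     # (1, 768, 24, 24)
tokens = tokens.flatten(2).transpose(1, 2)   # (1, 576, 768): 24*24 patch tokens

# Prepend a class token (learnable in the real model; zeros here just for the sketch)
cls_token = torch.zeros(1, 1, 768)
sequence = torch.cat([cls_token, tokens], dim=1)
sequence.shape
torch.Size([1, 577, 768])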

Exercise: Use a pre-trained ViT model to classify images

import torch
from torchvision import models

transformer_weights = models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1

transformer_weights.meta
{'categories': ['tench',
  'goldfish',
  'great white shark',
  ...
  'ear',
  'toilet tissue'],
 'recipe': 'https://github.com/facebookresearch/SWAG',
 'license': 'https://github.com/facebookresearch/SWAG/blob/main/LICENSE',
 'num_params': 86859496,
 'min_size': (384, 384),
 '_metrics': {'ImageNet-1K': {'acc@1': 85.304, 'acc@5': 97.65}},
 '_ops': 55.484,
 '_file_size': 331.398,
 '_docs': '\n                These weights are learnt via transfer learning by end-to-end fine-tuning the original\n                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.\n            '}
categories = transformer_weights.meta["categories"]
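The categories list maps each class index to its name, so the inverse lookup is a plain list.index; for example, for the class we will end up predicting below:

categories.index("space shuttle")
812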

Tip

More info about the Vision Transformer implementation in torchvision is available here

dl_model = models.vit_b_16(weights=transformer_weights, progress=True)

dl_model.eval()
VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      ... (encoder_layer_1 through encoder_layer_11 repeat the same structure) ...
    )
    (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  )
  (heads): Sequential(
    (head): Linear(in_features=768, out_features=1000, bias=True)
  )
)
import skimage
import matplotlib.pyplot as plt

sample_im = skimage.data.rocket()
sample_im.shape
(427, 640, 3)
plt.imshow(sample_im)
plt.show()

transformer_weights.transforms
functools.partial(<class 'torchvision.transforms._presets.ImageClassification'>, crop_size=384, resize_size=384, interpolation=<InterpolationMode.BICUBIC: 'bicubic'>)

Important

functools.partial creates a new callable with some arguments already fixed. So the attribute 👆 is a factory: calling transformer_weights.transforms() returns the actual preprocessing transform!
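A toy example of the same pattern (hypothetical function, just to illustrate partial):

from functools import partial

def power(base, exponent):
    return base ** exponent

square = partial(power, exponent=2)  # exponent is now pre-set
square(5)
25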

Note

The transforms used by these ViT weights are

  1. resize the image (bicubic) and center-crop it to 384x384 pixels, and

  2. normalize the values of the RGB channels with the ImageNet mean and standard deviation.

from torchvision.transforms.v2 import Compose, ToTensor

pipeline = Compose([
  ToTensor(),
  transformer_weights.transforms()
])

pipeline
Compose(
      ToTensor()
      ImageClassification(
      crop_size=[384]
      resize_size=[384]
      mean=[0.485, 0.456, 0.406]
      std=[0.229, 0.224, 0.225]
      interpolation=InterpolationMode.BICUBIC
  )
)
sample_x = pipeline(sample_im)
type(sample_x), sample_x.shape, sample_x.min(), sample_x.max()
(torch.Tensor, torch.Size([3, 384, 384]), tensor(-1.9139), tensor(2.8376))

Caution

Apply the model to sample_x[None, ...] so the image is treated as a one-sample batch; the model expects a leading batch dimension.
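sample_x[None, ...] is equivalent to sample_x.unsqueeze(0). Since this forward pass is inference only, it could also be wrapped in torch.no_grad() to skip building the autograd graph (gradients are only needed later, for the attention maps).

torch.equal(sample_x[None, ...], sample_x.unsqueeze(0))
True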

sample_y = dl_model(sample_x[None, ...])

sample_y.shape
torch.Size([1, 1000])
sample_y.argsort(dim=1)
tensor([[641, 646, 398, 926, 340, 429, 990, 645, 924, 469,
         ...
         900, 540, 754, 744, 657, 812]])

Note

The model outputs one logit (an unnormalized score) per class, not probabilities; softmax converts the logits into probabilities. Also note that argsort is ascending by default, so the most likely class (812, 'space shuttle') appears last in the tensor above; below, descending=True puts it first.
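A quick sanity check of that interpretation: softmax turns the logits into probabilities that sum to 1.

probs = torch.softmax(sample_y, dim=1)
probs.sum()     # -> tensor(1.) up to rounding
probs.argmax()  # -> tensor(812), i.e. categories[812] == 'space shuttle'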

sorted_predicted_classes = sample_y.argsort(dim=1, descending=True)[0, :10]
sorted_probs = torch.softmax(sample_y, dim=1)[0, sorted_predicted_classes]

for idx, prob in zip(sorted_predicted_classes, sorted_probs):
    print(categories[idx], "%3.2f %%" % (prob * 100))
space shuttle 79.91 %
missile 12.59 %
projectile 6.04 %
radio 0.11 %
drilling platform 0.08 %
water tower 0.05 %
radio telescope 0.05 %
beacon 0.04 %
solar dish 0.02 %
crane 0.02 %
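An equivalent way to get the ten most likely classes in a single call is torch.topk, which returns the values and indices already sorted in descending order; a sketch:

top_probs, top_idxs = torch.topk(torch.softmax(sample_y, dim=1), k=10, dim=1)
for idx, prob in zip(top_idxs[0], top_probs[0]):
    print(f"{categories[idx]} {prob * 100:.2f} %")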

Inspect the self-attention operations of the ViT

dl_model.encoder.layers[-1]
EncoderBlock(
  (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (self_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): MLPBlock(
    (0): Linear(in_features=768, out_features=3072, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=3072, out_features=768, bias=True)
    (4): Dropout(p=0.0, inplace=False)
  )
)
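Every submodule is reachable through a qualified name via named_modules(); these names are what the hook-registration code below filters on. For example, to list the self-attention modules of the pre-trained model:

for name, module in dl_model.named_modules():
    if name.endswith("self_attention"):
        print(name)
encoder.layers.encoder_layer_0.self_attention
...
encoder.layers.encoder_layer_11.self_attention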

Integrate a mechanism to review the attention maps of the model

Redefine some of the transformer's operations to enable storing the attention weights

from functools import partial
from typing import Callable
from timm.models.vision_transformer import Attention

class EncoderBlockAttnMap(models.vision_transformer.EncoderBlock):
    """Transformer encoder block."""

    def __init__(self,
                 num_heads: int,
                 hidden_dim: int,
                 mlp_dim: int,
                 dropout: float,
                 attention_dropout: float,
                 norm_layer: Callable[..., torch.nn.Module] = partial(torch.nn.LayerNorm, eps=1e-6)):
        # The definition is the same, only the forward function changes <------------------------------------
        super(EncoderBlockAttnMap, self).__init__(num_heads, hidden_dim, mlp_dim, dropout, attention_dropout, norm_layer)
        self.self_attention = Attention(hidden_dim, num_heads, attn_drop=attention_dropout, proj_drop=0.0, norm_layer=norm_layer, qkv_bias=True)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)

        # Modify this line, so we get the attention map from the self attention modules <--------------------
        x = self.self_attention(x)

        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)

        # The attention maps are captured by the hooks registered below, so the
        # block's output itself is unchanged <---------------------------------------------------
        return x + y
from collections import OrderedDict

class EncoderAttnMap(models.vision_transformer.Encoder):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(torch.nn.LayerNorm, eps=1e-6),
    ):
        super().__init__(seq_length, num_layers, num_heads, hidden_dim, mlp_dim, dropout, attention_dropout, norm_layer)

        layers: OrderedDict[str, torch.nn.Module] = OrderedDict()
        for i in range(num_layers):
            # Use the modified encoder block <---------------------------------------------------------------
            layers[f"encoder_layer_{i}"] = EncoderBlockAttnMap(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )

        self.layers = torch.nn.Sequential(layers)
# Redefine the full model to have access to the attention maps
class ViTAttnEnabled(models.vision_transformer.VisionTransformer):
    """ViT architecture with a modified encoder that exposes its attention maps.
    """
    def __init__(self, image_size, patch_size=14, num_layers=32, num_heads=16, hidden_dim=1280, mlp_dim=5120, **kwargs):
        super(ViTAttnEnabled, self).__init__(
            image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            **kwargs)

        # Change the encoder to the modified encoder that exposes the attention maps <-----------
        self.encoder = EncoderAttnMap(
            self.seq_length,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=0,
            attention_dropout=0,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6)
        )

        self.attentions = []
        self.attentions_gradients = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.detach())

    def get_attention_gradients(self, module, grad_input, grad_output):
        self.attentions_gradients.append(grad_input[0].detach())

    def register_attn_grad_hooks(self):
        for name, module in self.named_modules():
            if "self_attention.norm" in name:
                module.register_forward_hook(self.get_attention)
                module.register_full_backward_hook(self.get_attention_gradients)

    def clear_attentions(self):
        self.attentions.clear()
        self.attentions_gradients.clear()
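register_forward_hook and register_full_backward_hook are standard PyTorch mechanisms: a forward hook runs after the module's forward pass and receives (module, input, output), while a full backward hook receives the module's gradients during backpropagation. A minimal, self-contained sketch of the forward-hook pattern used above:

captured = []
toy = torch.nn.Linear(4, 2)
handle = toy.register_forward_hook(lambda mod, inp, out: captured.append(out.detach()))

_ = toy(torch.randn(1, 4))
len(captured), captured[0].shape
(1, torch.Size([1, 2]))

handle.remove()  # hooks can be removed when no longer needed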

Use a modified ViT model that enables tracking its attention weights

  • Initialize a ViT with the ViT-B-16 architecture
vit_model_self_attn = ViTAttnEnabled(
        image_size=384,
        patch_size=16,
        num_heads=12,
        num_layers=12,
        hidden_dim=768,
        mlp_dim=3072)
# torchvision's MultiheadAttention stores its projections as in_proj_*/out_proj.*,
# while timm's Attention names them qkv.*/proj.*; remap the parameter names so
# the pre-trained weights load into the modified model
name_map = {
    "in_proj_weight": "qkv.weight",
    "in_proj_bias": "qkv.bias",
    "out_proj.weight": "proj.weight",
    "out_proj.bias": "proj.bias"
}

transformer_weights_dict = transformer_weights.get_state_dict(progress=True)

vit_weights = {}
for k, v in transformer_weights_dict.items():
    # If this parameter belongs to an attention module, rename its suffix
    k_old = list(filter(lambda n: n in k, name_map.keys()))
    if len(k_old):
        k_old = k_old[0]
        prefix = k.split(k_old)[0]
        new_name = prefix + name_map[k_old]
    else:
        new_name = k

    vit_weights[new_name] = v

vit_model_self_attn.load_state_dict(vit_weights)
<All keys matched successfully>
vit_model_self_attn.register_attn_grad_hooks()

Apply the ViT class prediction on an image and compute the corresponding attention map

sample_im = skimage.io.imread("https://r0k.us/graphics/kodak/kodak/kodim20.png")

vit_model_self_attn.clear_attentions()

sample_x = pipeline(sample_im)

with torch.autograd.graph.save_on_cpu(pin_memory=True):
    sample_y = vit_model_self_attn(sample_x[None, ...])

attn_class = torch.argmax(sample_y, dim=1).item()
attn_class = torch.LongTensor([attn_class])
attn_class = torch.nn.functional.one_hot(attn_class, num_classes=sample_y.shape[1])

attn_class = torch.sum(attn_class * sample_y)

vit_model_self_attn.zero_grad()
attn_class.backward()

attn_out = [attn_tensor.clone() for attn_tensor in vit_model_self_attn.attentions]
grad_attn_out = [attn_tensor.clone() for attn_tensor in vit_model_self_attn.attentions_gradients]

Roll out the attention maps
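This is attention rollout (Abnar & Zuidema, 2020): the per-layer attention matrices are recursively multiplied so that the result traces attention from the output all the way back to the input patches. At each layer, the raw map is first weighted by its gradient (keeping only attention that contributed positively to the predicted class, similar in spirit to gradient-weighted rollout variants), an identity matrix is added to account for the residual connections, the rows are normalized, and the running product is updated as attn_rollout ← attn_map · attn_rollout.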

attn_rollout = torch.eye(attn_out[0].size(1))[None, ...]

for attn_map, attn_grad in zip(attn_out, grad_attn_out):
    if attn_grad is not None:
        attn_map = attn_map * attn_grad
        attn_map[attn_map < 0] = 0

    attn_map, _ = torch.topk(attn_map, 10, dim=-1)
    attn_map = attn_map.mean(dim=-1)

    # Normalize attention map
    attn_map = attn_map + torch.eye(attn_map.size(1), device=attn_map.device)[None, ...]
    attn_map = attn_map / attn_map.sum(dim=-1, keepdim=True)

    attn_rollout = torch.matmul(attn_map, attn_rollout)

attn_rollout.shape
torch.Size([1, 577, 577])
attn_rollout = attn_rollout[:, :1, 1:]  # CLS token's attention to the 576 patch tokens
attn_rollout = attn_rollout.reshape(1, -1, 24 ** 2)
attn_rollout.shape
torch.Size([1, 1, 576])
attn_rollout = torch.nn.functional.fold(attn_rollout.transpose(1, 2),
                      (24, 24),
                      kernel_size=(24, 24),
                      stride=(24, 24))
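# Note: with kernel_size == stride == (24, 24) and a single output block, this
# fold is equivalent to attn_rollout.reshape(1, 1, 24, 24); fold just makes the
# token-to-grid mapping explicit.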
attn_rollout = attn_rollout.squeeze()

attn_rollout = attn_rollout / torch.max(attn_rollout)

attn_rollout.shape
torch.Size([24, 24])

Visualize the attention map computed from the attention weights

  • Show an overlay of the attention map over the original image
plt.imshow(sample_im)
plt.imshow(attn_rollout, cmap="magma", extent=(0, sample_im.shape[1], sample_im.shape[0], 0), alpha=0.75)
plt.show()
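Here matplotlib stretches the 24x24 map over the image via extent and interpolates it on screen. To control the smoothing explicitly, the map could instead be upsampled to the image resolution first; a sketch:

attn_full = torch.nn.functional.interpolate(
    attn_rollout[None, None, ...],   # (1, 1, 24, 24)
    size=sample_im.shape[:2],        # (H, W) of the original image
    mode="bilinear",
    align_corners=False,
).squeeze()

plt.imshow(sample_im)
plt.imshow(attn_full, cmap="magma", alpha=0.75)
plt.axis("off")
plt.show()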