Advanced Machine Learning with Python (Session 2 - Part 1)

Fernando Cervantes (fernando.cervantes@jax.org)

Workshop outcomes

  • Understand the process of training ML models.
  • Load pre-trained ML models and fine-tune them with new data.
  • Evaluate the performance of ML models.
  • Adapt ML models for different tasks from pre-trained models.

Materials

  • Open notebook in Colab
  • View solutions

0. Setup environment

Select runtime and connect

In the top-right corner of the page, click the drop-down arrow to the right of the Connect button and select Change runtime type.

Make sure the Python 3 runtime type is selected. For this part of the workshop a CPU is enough, so no GPU is needed.

Now we can connect to the runtime by clicking Connect. This will create a Virtual Machine (VM) with compute resources we can use for a limited amount of time.

Caution

In free Colab accounts, these resources are not guaranteed and can be taken away without notice (preemptible machines).

Data stored in this runtime is lost when the runtime is deleted, unless it is first moved to other storage.

Load pre-trained models

  • Let’s use one from PyTorch’s torchvision package for computer vision.

  • Start with the InceptionV3 model.

Exercise: Use a pre-trained deep learning model to classify images

import torch
from torchvision import models

inception_weights = models.inception.Inception_V3_Weights.IMAGENET1K_V1

inception_weights.meta
{'num_params': 27161264,
 'min_size': (75, 75),
 'categories': ['tench',
  'goldfish',
  'great white shark',
  'tiger shark',
  'hammerhead',
  'electric ray',
  'stingray',
  'cock',
  'hen',
  'ostrich',
  'brambling',
  'goldfinch',
  'house finch',
  'junco',
  'indigo bunting',
  'robin',
  'bulbul',
  'jay',
  'magpie',
  'chickadee',
  'water ouzel',
  'kite',
  'bald eagle',
  'vulture',
  'great grey owl',
  'European fire salamander',
  'common newt',
  'eft',
  'spotted salamander',
  'axolotl',
  'bullfrog',
  'tree frog',
  'tailed frog',
  'loggerhead',
  'leatherback turtle',
  'mud turtle',
  'terrapin',
  'box turtle',
  'banded gecko',
  'common iguana',
  'American chameleon',
  'whiptail',
  'agama',
  'frilled lizard',
  'alligator lizard',
  'Gila monster',
  'green lizard',
  'African chameleon',
  'Komodo dragon',
  'African crocodile',
  'American alligator',
  'triceratops',
  'thunder snake',
  'ringneck snake',
  'hognose snake',
  'green snake',
  'king snake',
  'garter snake',
  'water snake',
  'vine snake',
  'night snake',
  'boa constrictor',
  'rock python',
  'Indian cobra',
  'green mamba',
  'sea snake',
  'horned viper',
  'diamondback',
  'sidewinder',
  'trilobite',
  'harvestman',
  'scorpion',
  'black and gold garden spider',
  'barn spider',
  'garden spider',
  'black widow',
  'tarantula',
  'wolf spider',
  'tick',
  'centipede',
  'black grouse',
  'ptarmigan',
  'ruffed grouse',
  'prairie chicken',
  'peacock',
  'quail',
  'partridge',
  'African grey',
  'macaw',
  'sulphur-crested cockatoo',
  'lorikeet',
  'coucal',
  'bee eater',
  'hornbill',
  'hummingbird',
  'jacamar',
  'toucan',
  'drake',
  'red-breasted merganser',
  'goose',
  'black swan',
  'tusker',
  'echidna',
  'platypus',
  'wallaby',
  'koala',
  'wombat',
  'jellyfish',
  'sea anemone',
  'brain coral',
  'flatworm',
  'nematode',
  'conch',
  'snail',
  'slug',
  'sea slug',
  'chiton',
  'chambered nautilus',
  'Dungeness crab',
  'rock crab',
  'fiddler crab',
  'king crab',
  'American lobster',
  'spiny lobster',
  'crayfish',
  'hermit crab',
  'isopod',
  'white stork',
  'black stork',
  'spoonbill',
  'flamingo',
  'little blue heron',
  'American egret',
  'bittern',
  'crane bird',
  'limpkin',
  'European gallinule',
  'American coot',
  'bustard',
  'ruddy turnstone',
  'red-backed sandpiper',
  'redshank',
  'dowitcher',
  'oystercatcher',
  'pelican',
  'king penguin',
  'albatross',
  'grey whale',
  'killer whale',
  'dugong',
  'sea lion',
  'Chihuahua',
  'Japanese spaniel',
  'Maltese dog',
  'Pekinese',
  'Shih-Tzu',
  'Blenheim spaniel',
  'papillon',
  'toy terrier',
  'Rhodesian ridgeback',
  'Afghan hound',
  'basset',
  'beagle',
  'bloodhound',
  'bluetick',
  'black-and-tan coonhound',
  'Walker hound',
  'English foxhound',
  'redbone',
  'borzoi',
  'Irish wolfhound',
  'Italian greyhound',
  'whippet',
  'Ibizan hound',
  'Norwegian elkhound',
  'otterhound',
  'Saluki',
  'Scottish deerhound',
  'Weimaraner',
  'Staffordshire bullterrier',
  'American Staffordshire terrier',
  'Bedlington terrier',
  'Border terrier',
  'Kerry blue terrier',
  'Irish terrier',
  'Norfolk terrier',
  'Norwich terrier',
  'Yorkshire terrier',
  'wire-haired fox terrier',
  'Lakeland terrier',
  'Sealyham terrier',
  'Airedale',
  'cairn',
  'Australian terrier',
  'Dandie Dinmont',
  'Boston bull',
  'miniature schnauzer',
  'giant schnauzer',
  'standard schnauzer',
  'Scotch terrier',
  'Tibetan terrier',
  'silky terrier',
  'soft-coated wheaten terrier',
  'West Highland white terrier',
  'Lhasa',
  'flat-coated retriever',
  'curly-coated retriever',
  'golden retriever',
  'Labrador retriever',
  'Chesapeake Bay retriever',
  'German short-haired pointer',
  'vizsla',
  'English setter',
  'Irish setter',
  'Gordon setter',
  'Brittany spaniel',
  'clumber',
  'English springer',
  'Welsh springer spaniel',
  'cocker spaniel',
  'Sussex spaniel',
  'Irish water spaniel',
  'kuvasz',
  'schipperke',
  'groenendael',
  'malinois',
  'briard',
  'kelpie',
  'komondor',
  'Old English sheepdog',
  'Shetland sheepdog',
  'collie',
  'Border collie',
  'Bouvier des Flandres',
  'Rottweiler',
  'German shepherd',
  'Doberman',
  'miniature pinscher',
  'Greater Swiss Mountain dog',
  'Bernese mountain dog',
  'Appenzeller',
  'EntleBucher',
  'boxer',
  'bull mastiff',
  'Tibetan mastiff',
  'French bulldog',
  'Great Dane',
  'Saint Bernard',
  'Eskimo dog',
  'malamute',
  'Siberian husky',
  'dalmatian',
  'affenpinscher',
  'basenji',
  'pug',
  'Leonberg',
  'Newfoundland',
  'Great Pyrenees',
  'Samoyed',
  'Pomeranian',
  'chow',
  'keeshond',
  'Brabancon griffon',
  'Pembroke',
  'Cardigan',
  'toy poodle',
  'miniature poodle',
  'standard poodle',
  'Mexican hairless',
  'timber wolf',
  'white wolf',
  'red wolf',
  'coyote',
  'dingo',
  'dhole',
  'African hunting dog',
  'hyena',
  'red fox',
  'kit fox',
  'Arctic fox',
  'grey fox',
  'tabby',
  'tiger cat',
  'Persian cat',
  'Siamese cat',
  'Egyptian cat',
  'cougar',
  'lynx',
  'leopard',
  'snow leopard',
  'jaguar',
  'lion',
  'tiger',
  'cheetah',
  'brown bear',
  'American black bear',
  'ice bear',
  'sloth bear',
  'mongoose',
  'meerkat',
  'tiger beetle',
  'ladybug',
  'ground beetle',
  'long-horned beetle',
  'leaf beetle',
  'dung beetle',
  'rhinoceros beetle',
  'weevil',
  'fly',
  'bee',
  'ant',
  'grasshopper',
  'cricket',
  'walking stick',
  'cockroach',
  'mantis',
  'cicada',
  'leafhopper',
  'lacewing',
  'dragonfly',
  'damselfly',
  'admiral',
  'ringlet',
  'monarch',
  'cabbage butterfly',
  'sulphur butterfly',
  'lycaenid',
  'starfish',
  'sea urchin',
  'sea cucumber',
  'wood rabbit',
  'hare',
  'Angora',
  'hamster',
  'porcupine',
  'fox squirrel',
  'marmot',
  'beaver',
  'guinea pig',
  'sorrel',
  'zebra',
  'hog',
  'wild boar',
  'warthog',
  'hippopotamus',
  'ox',
  'water buffalo',
  'bison',
  'ram',
  'bighorn',
  'ibex',
  'hartebeest',
  'impala',
  'gazelle',
  'Arabian camel',
  'llama',
  'weasel',
  'mink',
  'polecat',
  'black-footed ferret',
  'otter',
  'skunk',
  'badger',
  'armadillo',
  'three-toed sloth',
  'orangutan',
  'gorilla',
  'chimpanzee',
  'gibbon',
  'siamang',
  'guenon',
  'patas',
  'baboon',
  'macaque',
  'langur',
  'colobus',
  'proboscis monkey',
  'marmoset',
  'capuchin',
  'howler monkey',
  'titi',
  'spider monkey',
  'squirrel monkey',
  'Madagascar cat',
  'indri',
  'Indian elephant',
  'African elephant',
  'lesser panda',
  'giant panda',
  'barracouta',
  'eel',
  'coho',
  'rock beauty',
  'anemone fish',
  'sturgeon',
  'gar',
  'lionfish',
  'puffer',
  'abacus',
  'abaya',
  'academic gown',
  'accordion',
  'acoustic guitar',
  'aircraft carrier',
  'airliner',
  'airship',
  'altar',
  'ambulance',
  'amphibian',
  'analog clock',
  'apiary',
  'apron',
  'ashcan',
  'assault rifle',
  'backpack',
  'bakery',
  'balance beam',
  'balloon',
  'ballpoint',
  'Band Aid',
  'banjo',
  'bannister',
  'barbell',
  'barber chair',
  'barbershop',
  'barn',
  'barometer',
  'barrel',
  'barrow',
  'baseball',
  'basketball',
  'bassinet',
  'bassoon',
  'bathing cap',
  'bath towel',
  'bathtub',
  'beach wagon',
  'beacon',
  'beaker',
  'bearskin',
  'beer bottle',
  'beer glass',
  'bell cote',
  'bib',
  'bicycle-built-for-two',
  'bikini',
  'binder',
  'binoculars',
  'birdhouse',
  'boathouse',
  'bobsled',
  'bolo tie',
  'bonnet',
  'bookcase',
  'bookshop',
  'bottlecap',
  'bow',
  'bow tie',
  'brass',
  'brassiere',
  'breakwater',
  'breastplate',
  'broom',
  'bucket',
  'buckle',
  'bulletproof vest',
  'bullet train',
  'butcher shop',
  'cab',
  'caldron',
  'candle',
  'cannon',
  'canoe',
  'can opener',
  'cardigan',
  'car mirror',
  'carousel',
  "carpenter's kit",
  'carton',
  'car wheel',
  'cash machine',
  'cassette',
  'cassette player',
  'castle',
  'catamaran',
  'CD player',
  'cello',
  'cellular telephone',
  'chain',
  'chainlink fence',
  'chain mail',
  'chain saw',
  'chest',
  'chiffonier',
  'chime',
  'china cabinet',
  'Christmas stocking',
  'church',
  'cinema',
  'cleaver',
  'cliff dwelling',
  'cloak',
  'clog',
  'cocktail shaker',
  'coffee mug',
  'coffeepot',
  'coil',
  'combination lock',
  'computer keyboard',
  'confectionery',
  'container ship',
  'convertible',
  'corkscrew',
  'cornet',
  'cowboy boot',
  'cowboy hat',
  'cradle',
  'crane',
  'crash helmet',
  'crate',
  'crib',
  'Crock Pot',
  'croquet ball',
  'crutch',
  'cuirass',
  'dam',
  'desk',
  'desktop computer',
  'dial telephone',
  'diaper',
  'digital clock',
  'digital watch',
  'dining table',
  'dishrag',
  'dishwasher',
  'disk brake',
  'dock',
  'dogsled',
  'dome',
  'doormat',
  'drilling platform',
  'drum',
  'drumstick',
  'dumbbell',
  'Dutch oven',
  'electric fan',
  'electric guitar',
  'electric locomotive',
  'entertainment center',
  'envelope',
  'espresso maker',
  'face powder',
  'feather boa',
  'file',
  'fireboat',
  'fire engine',
  'fire screen',
  'flagpole',
  'flute',
  'folding chair',
  'football helmet',
  'forklift',
  'fountain',
  'fountain pen',
  'four-poster',
  'freight car',
  'French horn',
  'frying pan',
  'fur coat',
  'garbage truck',
  'gasmask',
  'gas pump',
  'goblet',
  'go-kart',
  'golf ball',
  'golfcart',
  'gondola',
  'gong',
  'gown',
  'grand piano',
  'greenhouse',
  'grille',
  'grocery store',
  'guillotine',
  'hair slide',
  'hair spray',
  'half track',
  'hammer',
  'hamper',
  'hand blower',
  'hand-held computer',
  'handkerchief',
  'hard disc',
  'harmonica',
  'harp',
  'harvester',
  'hatchet',
  'holster',
  'home theater',
  'honeycomb',
  'hook',
  'hoopskirt',
  'horizontal bar',
  'horse cart',
  'hourglass',
  'iPod',
  'iron',
  "jack-o'-lantern",
  'jean',
  'jeep',
  'jersey',
  'jigsaw puzzle',
  'jinrikisha',
  'joystick',
  'kimono',
  'knee pad',
  'knot',
  'lab coat',
  'ladle',
  'lampshade',
  'laptop',
  'lawn mower',
  'lens cap',
  'letter opener',
  'library',
  'lifeboat',
  'lighter',
  'limousine',
  'liner',
  'lipstick',
  'Loafer',
  'lotion',
  'loudspeaker',
  'loupe',
  'lumbermill',
  'magnetic compass',
  'mailbag',
  'mailbox',
  'maillot',
  'maillot tank suit',
  'manhole cover',
  'maraca',
  'marimba',
  'mask',
  'matchstick',
  'maypole',
  'maze',
  'measuring cup',
  'medicine chest',
  'megalith',
  'microphone',
  'microwave',
  'military uniform',
  'milk can',
  'minibus',
  'miniskirt',
  'minivan',
  'missile',
  'mitten',
  'mixing bowl',
  'mobile home',
  'Model T',
  'modem',
  'monastery',
  'monitor',
  'moped',
  'mortar',
  'mortarboard',
  'mosque',
  'mosquito net',
  'motor scooter',
  'mountain bike',
  'mountain tent',
  'mouse',
  'mousetrap',
  'moving van',
  'muzzle',
  'nail',
  'neck brace',
  'necklace',
  'nipple',
  'notebook',
  'obelisk',
  'oboe',
  'ocarina',
  'odometer',
  'oil filter',
  'organ',
  'oscilloscope',
  'overskirt',
  'oxcart',
  'oxygen mask',
  'packet',
  'paddle',
  'paddlewheel',
  'padlock',
  'paintbrush',
  'pajama',
  'palace',
  'panpipe',
  'paper towel',
  'parachute',
  'parallel bars',
  'park bench',
  'parking meter',
  'passenger car',
  'patio',
  'pay-phone',
  'pedestal',
  'pencil box',
  'pencil sharpener',
  'perfume',
  'Petri dish',
  'photocopier',
  'pick',
  'pickelhaube',
  'picket fence',
  'pickup',
  'pier',
  'piggy bank',
  'pill bottle',
  'pillow',
  'ping-pong ball',
  'pinwheel',
  'pirate',
  'pitcher',
  'plane',
  'planetarium',
  'plastic bag',
  'plate rack',
  'plow',
  'plunger',
  'Polaroid camera',
  'pole',
  'police van',
  'poncho',
  'pool table',
  'pop bottle',
  'pot',
  "potter's wheel",
  'power drill',
  'prayer rug',
  'printer',
  'prison',
  'projectile',
  'projector',
  'puck',
  'punching bag',
  'purse',
  'quill',
  'quilt',
  'racer',
  'racket',
  'radiator',
  'radio',
  'radio telescope',
  'rain barrel',
  'recreational vehicle',
  'reel',
  'reflex camera',
  'refrigerator',
  'remote control',
  'restaurant',
  'revolver',
  'rifle',
  'rocking chair',
  'rotisserie',
  'rubber eraser',
  'rugby ball',
  'rule',
  'running shoe',
  'safe',
  'safety pin',
  'saltshaker',
  'sandal',
  'sarong',
  'sax',
  'scabbard',
  'scale',
  'school bus',
  'schooner',
  'scoreboard',
  'screen',
  'screw',
  'screwdriver',
  'seat belt',
  'sewing machine',
  'shield',
  'shoe shop',
  'shoji',
  'shopping basket',
  'shopping cart',
  'shovel',
  'shower cap',
  'shower curtain',
  'ski',
  'ski mask',
  'sleeping bag',
  'slide rule',
  'sliding door',
  'slot',
  'snorkel',
  'snowmobile',
  'snowplow',
  'soap dispenser',
  'soccer ball',
  'sock',
  'solar dish',
  'sombrero',
  'soup bowl',
  'space bar',
  'space heater',
  'space shuttle',
  'spatula',
  'speedboat',
  'spider web',
  'spindle',
  'sports car',
  'spotlight',
  'stage',
  'steam locomotive',
  'steel arch bridge',
  'steel drum',
  'stethoscope',
  'stole',
  'stone wall',
  'stopwatch',
  'stove',
  'strainer',
  'streetcar',
  'stretcher',
  'studio couch',
  'stupa',
  'submarine',
  'suit',
  'sundial',
  'sunglass',
  'sunglasses',
  'sunscreen',
  'suspension bridge',
  'swab',
  'sweatshirt',
  'swimming trunks',
  'swing',
  'switch',
  'syringe',
  'table lamp',
  'tank',
  'tape player',
  'teapot',
  'teddy',
  'television',
  'tennis ball',
  'thatch',
  'theater curtain',
  'thimble',
  'thresher',
  'throne',
  'tile roof',
  'toaster',
  'tobacco shop',
  'toilet seat',
  'torch',
  'totem pole',
  'tow truck',
  'toyshop',
  'tractor',
  'trailer truck',
  'tray',
  'trench coat',
  'tricycle',
  'trimaran',
  'tripod',
  'triumphal arch',
  'trolleybus',
  'trombone',
  'tub',
  'turnstile',
  'typewriter keyboard',
  'umbrella',
  'unicycle',
  'upright',
  'vacuum',
  'vase',
  'vault',
  'velvet',
  'vending machine',
  'vestment',
  'viaduct',
  'violin',
  'volleyball',
  'waffle iron',
  'wall clock',
  'wallet',
  'wardrobe',
  'warplane',
  'washbasin',
  'washer',
  'water bottle',
  'water jug',
  'water tower',
  'whiskey jug',
  'whistle',
  'wig',
  'window screen',
  'window shade',
  'Windsor tie',
  'wine bottle',
  'wing',
  'wok',
  'wooden spoon',
  'wool',
  'worm fence',
  'wreck',
  'yawl',
  'yurt',
  'web site',
  'comic book',
  'crossword puzzle',
  'street sign',
  'traffic light',
  'book jacket',
  'menu',
  'plate',
  'guacamole',
  'consomme',
  'hot pot',
  'trifle',
  'ice cream',
  'ice lolly',
  'French loaf',
  'bagel',
  'pretzel',
  'cheeseburger',
  'hotdog',
  'mashed potato',
  'head cabbage',
  'broccoli',
  'cauliflower',
  'zucchini',
  'spaghetti squash',
  'acorn squash',
  'butternut squash',
  'cucumber',
  'artichoke',
  'bell pepper',
  'cardoon',
  'mushroom',
  'Granny Smith',
  'strawberry',
  'orange',
  'lemon',
  'fig',
  'pineapple',
  'banana',
  'jackfruit',
  'custard apple',
  'pomegranate',
  'hay',
  'carbonara',
  'chocolate sauce',
  'dough',
  'meat loaf',
  'pizza',
  'potpie',
  'burrito',
  'red wine',
  'espresso',
  'cup',
  'eggnog',
  'alp',
  'bubble',
  'cliff',
  'coral reef',
  'geyser',
  'lakeside',
  'promontory',
  'sandbar',
  'seashore',
  'valley',
  'volcano',
  'ballplayer',
  'groom',
  'scuba diver',
  'rapeseed',
  'daisy',
  "yellow lady's slipper",
  'corn',
  'acorn',
  'hip',
  'buckeye',
  'coral fungus',
  'agaric',
  'gyromitra',
  'stinkhorn',
  'earthstar',
  'hen-of-the-woods',
  'bolete',
  'ear',
  'toilet tissue'],
 'recipe': 'https://github.com/pytorch/vision/tree/main/references/classification#inception-v3',
 '_metrics': {'ImageNet-1K': {'acc@1': 77.294, 'acc@5': 93.45}},
 '_ops': 5.713,
 '_file_size': 103.903,
 '_docs': 'These weights are ported from the original paper.'}
categories = inception_weights.meta["categories"]

Tip

More information about the Inception V3 implementation in torchvision is available here.

dl_model = models.inception_v3(weights=inception_weights, progress=True)

dl_model.eval()
Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Mixed_5b): InceptionA(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch5x5_1): BasicConv2d(
      (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch5x5_2): BasicConv2d(
      (conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_2): BasicConv2d(
      (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3): BasicConv2d(
      (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_5c): InceptionA(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch5x5_1): BasicConv2d(
      (conv): Conv2d(256, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch5x5_2): BasicConv2d(
      (conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_1): BasicConv2d(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_2): BasicConv2d(
      (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3): BasicConv2d(
      (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_5d): InceptionA(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(288, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch5x5_1): BasicConv2d(
      (conv): Conv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch5x5_2): BasicConv2d(
      (conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_1): BasicConv2d(
      (conv): Conv2d(288, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_2): BasicConv2d(
      (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3): BasicConv2d(
      (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(288, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_6a): InceptionB(
    (branch3x3): BasicConv2d(
      (conv): Conv2d(288, 384, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_1): BasicConv2d(
      (conv): Conv2d(288, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_2): BasicConv2d(
      (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3): BasicConv2d(
      (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_6b): InceptionC(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_1): BasicConv2d(
      (conv): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_2): BasicConv2d(
      (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_3): BasicConv2d(
      (conv): Conv2d(128, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_1): BasicConv2d(
      (conv): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_2): BasicConv2d(
      (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_3): BasicConv2d(
      (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_4): BasicConv2d(
      (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_5): BasicConv2d(
      (conv): Conv2d(128, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_6c): InceptionC(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_1): BasicConv2d(
      (conv): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_2): BasicConv2d(
      (conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_3): BasicConv2d(
      (conv): Conv2d(160, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_1): BasicConv2d(
      (conv): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_2): BasicConv2d(
      (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_3): BasicConv2d(
      (conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_4): BasicConv2d(
      (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_5): BasicConv2d(
      (conv): Conv2d(160, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_6d): InceptionC(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_1): BasicConv2d(
      (conv): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_2): BasicConv2d(
      (conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_3): BasicConv2d(
      (conv): Conv2d(160, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_1): BasicConv2d(
      (conv): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_2): BasicConv2d(
      (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_3): BasicConv2d(
      (conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_4): BasicConv2d(
      (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_5): BasicConv2d(
      (conv): Conv2d(160, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_6e): InceptionC(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_2): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7_3): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_2): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_3): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_4): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7dbl_5): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (AuxLogits): InceptionAux(
    (conv0): BasicConv2d(
      (conv): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv1): BasicConv2d(
      (conv): Conv2d(128, 768, kernel_size=(5, 5), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fc): Linear(in_features=768, out_features=1000, bias=True)
  )
  (Mixed_7a): InceptionD(
    (branch3x3_1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_2): BasicConv2d(
      (conv): Conv2d(192, 320, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7x3_1): BasicConv2d(
      (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7x3_2): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7x3_3): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch7x7x3_4): BasicConv2d(
      (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_7b): InceptionE(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_1): BasicConv2d(
      (conv): Conv2d(1280, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_2a): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_2b): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_1): BasicConv2d(
      (conv): Conv2d(1280, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(448, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_2): BasicConv2d(
      (conv): Conv2d(448, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3a): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3b): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(1280, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (Mixed_7c): InceptionE(
    (branch1x1): BasicConv2d(
      (conv): Conv2d(2048, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_1): BasicConv2d(
      (conv): Conv2d(2048, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_2a): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3_2b): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_1): BasicConv2d(
      (conv): Conv2d(2048, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(448, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_2): BasicConv2d(
      (conv): Conv2d(448, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3a): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch3x3dbl_3b): BasicConv2d(
      (conv): Conv2d(384, 384, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch_pool): BasicConv2d(
      (conv): Conv2d(2048, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

Exercise: Use a pre-trained deep learning model to classify images

import skimage
import matplotlib.pyplot as plt

sample_im = skimage.data.rocket()
sample_im.shape
(427, 640, 3)

plt.imshow(sample_im)
plt.show()

inception_weights.transforms
functools.partial(<class 'torchvision.transforms._presets.ImageClassification'>, crop_size=299, resize_size=342)

Important

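functools.partial creates a new callable with some arguments of the original function fixed in advance. So inception_weights.transforms is not a transform yet: calling it, as in inception_weights.transforms(), returns an ImageClassification transform configured with crop_size=299 and resize_size=342.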
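Here is the same mechanism using only the standard library, as a minimal sketch. describe_transform is a hypothetical stand-in for the ImageClassification class, not part of torchvision.

```python
from functools import partial


def describe_transform(crop_size, resize_size):
    # Hypothetical stand-in for a transform class: it only reports its settings
    return f"resize to {resize_size}, then crop {crop_size}"


# Fix both arguments now; describe_transform is not called yet
make_transform = partial(describe_transform, crop_size=299, resize_size=342)

# Calling the partial object runs the original function with the stored arguments
print(make_transform())  # resize to 342, then crop 299

# Stored keyword arguments can still be overridden at call time
print(make_transform(crop_size=224))  # resize to 342, then crop 224
```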

Note

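The transforms associated with the Inception V3 weights are

  1. resize the image so that its shorter side is 342 pixels (the aspect ratio is preserved),

  2. crop the central 299x299 pixel window, and

  3. convert the pixel values to floats and normalize each RGB channel with the ImageNet mean and standard deviation.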
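The NumPy sketch below makes the arithmetic of the three steps explicit. It is an approximation, not torchvision's implementation: it assumes an H x W x 3 uint8 RGB input and uses nearest-neighbour sampling instead of bilinear interpolation. resized_shape and preprocess are hypothetical helper names.

```python
import numpy as np

# ImageNet statistics, as listed in the pipeline printout
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def resized_shape(h, w, short_side=342):
    """Scale (h, w) so the shorter side equals short_side, keeping the aspect ratio."""
    if h <= w:
        return short_side, int(short_side * w / h)
    return int(short_side * h / w), short_side


def preprocess(img, short_side=342, crop=299):
    """H x W x 3 uint8 RGB array -> 3 x crop x crop normalized float array."""
    h, w, _ = img.shape
    new_h, new_w = resized_shape(h, w, short_side)

    # 1. Resize (nearest-neighbour sampling as a simple stand-in for bilinear)
    rows = (np.arange(new_h) * h / new_h).astype(int)
    cols = (np.arange(new_w) * w / new_w).astype(int)
    resized = img[rows][:, cols]

    # 2. Keep the central crop x crop window
    top = (new_h - crop) // 2
    left = (new_w - crop) // 2
    cropped = resized[top:top + crop, left:left + crop]

    # 3. Scale to [0, 1], then normalize each RGB channel
    x = (cropped / 255.0 - MEAN) / STD

    # Channels first (C x H x W), the layout PyTorch models expect
    return x.transpose(2, 0, 1)


x = preprocess(np.zeros((427, 640, 3), dtype=np.uint8))
print(x.shape)  # (3, 299, 299)
```

For the 427x640 rocket image, the shorter side becomes 342 and the longer one int(342 * 640 / 427) = 512. The crop therefore keeps the central 299x299 window of a 342x512 image.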

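from torchvision.transforms.v2 import Compose, ToTensor

# ToTensor turns the H x W x 3 uint8 NumPy array into a 3 x H x W float tensor in [0, 1]
# (newer torchvision releases deprecate it in favor of ToImage() + ToDtype(torch.float32, scale=True));
# inception_weights.transforms() then resizes, crops, and normalizes that tensor
pipeline = Compose([
  ToTensor(),
  inception_weights.transforms()
])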

pipeline
Compose(
      ToTensor()
      ImageClassification(
      crop_size=[299]
      resize_size=[342]
      mean=[0.485, 0.456, 0.406]
      std=[0.229, 0.224, 0.225]
      interpolation=InterpolationMode.BILINEAR
  )
)
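Compose chains the transforms, feeding the output of each one into the next. Here is a minimal pure-Python sketch of that idea. SimpleCompose is a hypothetical name, not the torchvision class.

```python
class SimpleCompose:
    """Apply a list of callables in order, like a minimal Compose."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        # The output of each transform becomes the input of the next one
        for transform in self.transforms:
            x = transform(x)
        return x


# Order matters: add 1 first, then double
pipeline_sketch = SimpleCompose([lambda x: x + 1, lambda x: x * 2])
print(pipeline_sketch(3))  # 8
```

This ordering is why ToTensor comes first in the pipeline: it converts the NumPy array into a tensor before the resize, crop, and normalization steps run.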

Exercise: Use a pretrained deep learning model to classify images

sample_x = pipeline(sample_im)
type(sample_x), sample_x.shape, sample_x.min(), sample_x.max()
(torch.Tensor, torch.Size([3, 299, 299]), tensor(-1.8792), tensor(2.6400))
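The minimum and maximum are consistent with the normalization step. After ToTensor, the pixel values lie in [0, 1], and each channel is then mapped through (x - mean) / std. The following check computes the possible range per channel from the mean and std printed above:

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Normalized value of a black (0.0) and a white (1.0) pixel in each channel
bounds = {
    channel: ((0.0 - m) / s, (1.0 - m) / s)
    for channel, m, s in zip("RGB", mean, std)
}

for channel, (low, high) in bounds.items():
    print(f"{channel}: [{low:.4f}, {high:.4f}]")
```

It prints R: [-2.1179, 2.2489], G: [-2.0357, 2.4286], and B: [-1.8044, 2.6400]. The observed maximum, 2.6400, matches the upper bound of the blue channel, and the observed minimum, -1.8792, lies inside the overall range.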

Caution

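Apply the model to sample_x[None, ...] instead of sample_x. Indexing with None adds a leading batch dimension (shape [1, 3, 299, 299]), so the image is treated as a one-sample batch.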

sample_y = dl_model(sample_x[None, ...])

sample_y.shape
torch.Size([1, 1000])
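Indexing with None works the same way on NumPy arrays and PyTorch tensors: it inserts a new axis of length 1. The NumPy sketch below uses a zero array as a placeholder for the preprocessed image:

```python
import numpy as np

x = np.zeros((3, 299, 299))  # placeholder for one preprocessed image (C, H, W)

batch = x[None, ...]         # add a leading batch axis
print(batch.shape)           # (1, 3, 299, 299)

# Equivalent spelling; in PyTorch, sample_x.unsqueeze(0) does the same
print(np.expand_dims(x, axis=0).shape)  # (1, 3, 299, 299)
```

Likewise, the model's output has shape [1, 1000]: one row of 1000 class scores per image in the batch. This is why the next steps index it with [0, ...].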

Note

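The model’s outputs are not probabilities. They are raw, unnormalized scores (logits) for sample_x belonging to each of the 1000 classes: a higher score means a more likely class, and torch.softmax converts the scores into probabilities. Because argsort sorts in ascending order, the most likely class is the last index in the output below.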

sample_y.argsort(dim=1)
tensor([[578, 253, 339, 982, 301,  40, 584, 897, 382,  12, 170, 181,  25, 161,
         255, 162, 159,  14, 500,  72,  44, 368, 714, 211, 201, 240,  90, 277,
          99, 519,  46,  13, 790, 160, 883, 307,  31, 467, 271,  32, 452, 997,
          92, 414, 286,  64, 241, 661, 360, 381, 473, 999, 316,  24, 184, 213,
         804,  26, 889, 383, 689,  11, 231,  28, 207, 520, 134,  77,  37, 177,
         513, 539, 191, 264, 306,  19, 103, 529, 823, 171, 838, 247, 174, 400,
         228, 665, 669, 278,  16, 870, 354, 412, 272, 377, 175, 411, 551,  75,
         265, 390, 601, 929,  27, 543, 434, 268,  94, 200,  33, 504, 760, 614,
         387, 875, 167, 809, 588, 189,  30, 291, 703, 944, 323, 610, 631,  57,
          41, 881, 349, 938, 493, 825, 732, 249, 593, 299, 516, 237, 246, 195,
         137, 850, 756,  91, 695, 336, 799, 431, 766, 761, 310, 337, 894, 622,
         305, 129, 351, 989, 581, 991, 612, 110, 672, 436, 239, 226,  93,  76,
         785,  21, 260,  63, 294, 176, 960, 528, 606, 826, 168, 166,  87, 531,
         355, 154,  15, 774, 717, 322,  18,   0, 193, 435, 152, 678, 722, 269,
          86,  47, 292, 759, 248, 317, 105, 302, 621, 376,  60, 793, 295, 283,
         692, 720, 495, 859, 284, 393, 267, 596, 537, 656, 235, 831, 707, 444,
          85, 130, 232, 534, 560, 204, 188, 430, 379, 853,  43,  61, 261, 684,
         932, 172, 746, 178, 986, 876, 615, 996, 770, 700, 298, 459, 384, 273,
         328,  52, 605, 736, 968,  95, 410, 922, 296, 259, 347, 401, 518, 256,
         912, 937, 319,  97, 155, 463, 636, 457,  53, 671, 119, 893, 300, 281,
         921, 752, 394, 163, 244, 340, 697, 985, 763, 553, 243, 133, 568, 210,
         454, 933, 597, 443,  20, 308, 451, 579, 549, 164, 716, 917,  59, 592,
         136, 114, 185, 690, 472, 878, 642, 135, 771, 639, 276, 456, 485, 238,
         943, 582, 486, 544, 979,  42, 868, 511, 321, 156,   9, 721, 365, 113,
         263, 910,  88, 499, 670,  17, 635, 734, 266, 984, 386, 140, 280, 750,
         407,  10, 654, 197, 478,  36,  62, 775, 638,  55, 275, 861, 230, 421,
         603, 646, 447, 112, 233, 788, 222, 655, 492, 423, 896, 618, 675, 993,
         794, 616, 215, 915, 196, 402, 229, 852, 364,  22, 139,  81, 685, 713,
         309, 314, 153, 116, 547, 217, 643, 566, 719, 874,  98, 330, 587, 254,
         458, 623, 886, 901, 842, 541, 930, 817, 507, 946, 362, 849,  73, 786,
         939,  35, 157, 396, 250, 326, 725, 345, 617, 580, 988, 778, 552, 251,
         187, 865,  71, 887, 242, 651, 448,  23, 559, 359, 202, 145, 180, 857,
         691, 182, 475, 813, 743, 131, 843, 304, 318, 353, 422, 440,  89, 234,
         704, 706,  68, 728, 150, 395, 108, 391, 512, 425, 468, 947, 143, 293,
         350, 526, 257, 550, 957, 970, 955, 575,   5, 101, 123, 945, 633, 104,
          80, 502, 729, 508,  49, 225, 258, 426, 637, 183, 995, 789, 344, 906,
         948, 179, 950, 398, 397, 664, 532, 903, 198, 626, 194, 236, 667, 676,
         757, 465, 324, 335, 218, 303, 699, 496, 934, 158,  38, 441, 433, 992,
         375, 357, 810, 815, 693, 192, 585, 107, 586, 289, 331, 735, 772, 125,
         515, 416, 392, 482, 869, 608, 681, 607, 882, 640, 380, 126, 854, 924,
         768, 312, 589, 837, 141, 570, 315, 653, 102, 572, 479, 369, 956, 503,
         115, 481,  74, 797, 122, 535, 327, 352,  45, 953, 802,  48, 389,  96,
         270, 723, 224, 378, 795, 424, 450, 461, 796, 100,  39, 252, 800, 648,
         679, 121, 787, 453, 925, 613, 846, 726, 702, 173,   8, 221, 287, 926,
         169, 341, 325, 186, 128, 356, 709, 455, 645, 983,   6,  67, 449, 285,
         972,  65, 765, 564, 282, 663, 677, 118, 806,  29, 576, 151, 388, 748,
         782, 311, 358, 963, 329, 824, 333, 533, 203, 905, 673, 830, 951, 480,
         419, 445, 829, 127, 904, 858,  69, 371, 313, 208, 753, 509, 521, 641,
         219, 214, 209,  56, 928, 462, 138, 149, 776, 987, 523, 206, 916, 332,
          51,   1, 320, 262, 798, 524, 805, 514, 627, 274, 705, 334, 420, 227,
         773, 474, 205, 696, 899,  84, 212, 111,   7, 483, 801, 109, 990, 710,
         866, 488,  78, 791, 602, 851, 747, 971,  70,  82, 439, 848, 367, 873,
         591, 958, 505, 464, 342, 290, 674, 659, 898, 288, 803, 967, 594, 370,
         730, 385, 567, 779, 223, 374, 429, 649, 741, 836, 711,  50, 952, 165,
         598, 739, 546, 487, 686, 658, 749, 715, 501, 609, 890, 954, 432,  34,
         769, 428, 647, 973, 962, 742, 981,   4, 619, 892,  58, 841, 964,  83,
         142, 891, 361, 885, 909, 927, 124, 777, 363, 220, 373, 144, 406, 660,
         731, 297, 783, 630, 245, 577, 106, 583, 965, 911, 427, 556, 834, 132,
         762, 469, 569, 738, 199, 491, 446, 348,   3, 346, 147, 902, 808, 740,
         936, 908, 819, 372, 827, 624, 975, 687, 548, 879, 611,  79, 574, 662,
         476, 522, 864, 745, 969, 880, 998, 680, 343, 632, 884, 877, 497, 835,
         477, 701, 811, 573, 818, 120, 949, 565, 190, 767, 914, 650, 418, 931,
         216, 844, 872, 666, 415, 527, 117, 604, 148, 438, 923, 279, 599, 888,
         698, 961, 855, 941, 338, 978, 466, 538, 688, 724, 839, 792, 542, 652,
         525, 814, 918, 708, 366, 590, 860, 862, 600, 784, 942, 966, 976, 867,
           2, 146, 489, 413, 712, 595, 644, 828, 994,  54, 399, 629, 780, 727,
         822, 959, 764, 935, 832, 755, 751, 417, 977, 561, 558, 545, 625, 490,
          66, 470, 980, 974, 816, 571, 863, 506, 737, 494, 555, 907, 409, 840,
         460, 563, 471, 634, 620, 820, 484, 530, 913, 536, 403, 718, 856, 758,
         845, 919, 821, 562, 871, 920, 404, 408, 683, 781, 940, 510, 847, 498,
         628, 694, 895, 554, 405, 900, 442, 754, 833, 437, 807, 682, 668, 557,
         812, 733, 517, 540, 744, 657]])
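Because argsort orders indices from the smallest to the largest score, the top prediction is the last entry of the row above. The NumPy sketch below uses four made-up scores to show two ways of listing the most likely classes first:

```python
import numpy as np

scores = np.array([0.5, 2.0, -1.0, 1.2])  # made-up logits for 4 classes

ascending = np.argsort(scores)  # least to most likely
print(ascending)                # [2 0 3 1]

top_first = ascending[::-1]     # reversed: most likely first
print(top_first)                # [1 3 0 2]

# NumPy's argsort has no descending flag; PyTorch's does:
# sample_y.argsort(dim=1, descending=True)
```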

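# Indices of the 10 highest-scoring classes, from most to least likely
sorted_predicted_classes = sample_y.argsort(dim=1, descending=True)[0, :10]
# Convert the logits into probabilities and keep those of the top 10 classes
sorted_probs = torch.softmax(sample_y, dim=1)[0, sorted_predicted_classes]

for idx, prob in zip(sorted_predicted_classes, sorted_probs):
    print(categories[idx], "%3.2f %%" % (prob * 100))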
missile 37.93 %
projectile 14.39 %
drilling platform 12.16 %
crane 5.72 %
pole 1.06 %
space shuttle 0.93 %
flagpole 0.93 %
mosque 0.51 %
obelisk 0.46 %
solar dish 0.45 %
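torch.softmax turns the logits into probabilities. It exponentiates each score and divides by the sum of the exponentials, so the results are positive, add up to 1, and keep the same ranking as the logits. Here is a minimal NumPy version. Subtracting the maximum first is a common trick to avoid overflow, and it does not change the result.

```python
import numpy as np


def softmax(logits):
    """Convert a 1-D array of logits into probabilities that sum to 1."""
    shifted = logits - logits.max()  # numerical stability; the result is unchanged
    exp = np.exp(shifted)
    return exp / exp.sum()


logits = np.array([2.0, 1.0, 0.1])  # made-up scores for 3 classes
probs = softmax(logits)
print(probs.round(3))  # [0.659 0.242 0.099]
print(probs.sum())
```

This explains why the ten percentages printed above add up to less than 100 %: the remaining probability is spread over the other 990 classes.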

Try with other sample images (only works with RGB!)

Caution

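Only works with RGB images. The pipeline expects an H x W x 3 array, so grayscale (H x W) images and images with an alpha channel (H x W x 4) must be converted to RGB first.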
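The following minimal NumPy sketch converts such images before they reach pipeline. It assumes the image is a NumPy array, as returned by skimage.io.imread. to_rgb is a hypothetical helper that simply drops the alpha channel. skimage.color.rgba2rgb instead blends the alpha channel with a background color, and skimage.color.gray2rgb is a library alternative for grayscale images.

```python
import numpy as np


def to_rgb(im):
    """Return an H x W x 3 version of a grayscale, RGB, or RGBA image (hypothetical helper)."""
    if im.ndim == 2:
        # Grayscale: repeat the single channel three times
        return np.stack([im, im, im], axis=-1)
    if im.ndim == 3 and im.shape[-1] == 4:
        # RGBA: drop the alpha channel (assumes the image is mostly opaque)
        return im[..., :3]
    if im.ndim == 3 and im.shape[-1] == 3:
        # Already RGB
        return im
    raise ValueError(f"Unsupported image shape: {im.shape}")


gray = np.zeros((240, 320), dtype=np.uint8)
print(to_rgb(gray).shape)  # (240, 320, 3)
```

With this helper, a non-RGB image would be classified with sample_x = pipeline(to_rgb(sample_im)).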

# Download a photo from Wikimedia Commons and classify it as a one-sample batch
sample_im = skimage.io.imread("https://upload.wikimedia.org/wikipedia/commons/thumb/c/c8/Black_Labrador_Retriever_-_Male_IMG_3323.jpg/1280px-Black_Labrador_Retriever_-_Male_IMG_3323.jpg")
sample_x = pipeline(sample_im)
sample_y = dl_model(sample_x[None, ...])

# Show the image with the top-1 predicted category as its title
plt.imshow(sample_im)
plt.title(categories[sample_y.argmax(dim=1).item()])
plt.show()

# Print the 10 most likely categories and their probabilities
sorted_predicted_classes = sample_y.argsort(dim=1, descending=True)[0, :10]
sorted_probs = torch.softmax(sample_y, dim=1)[0, sorted_predicted_classes]

for idx, prob in zip(sorted_predicted_classes, sorted_probs):
    print(categories[idx], "%3.2f %%" % (prob * 100))

Labrador retriever 87.15 %
German short-haired pointer 3.31 %
flat-coated retriever 0.92 %
Great Dane 0.83 %
black-and-tan coonhound 0.52 %
giant schnauzer 0.34 %
curly-coated retriever 0.22 %
Staffordshire bullterrier 0.19 %
Doberman 0.14 %
Rottweiler 0.07 %

Using Deep Learning models as feature extractors

Exercise: Modify the classifier layer (fc) of the model to return the feature vector of the input image instead of the category scores

Tip

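The classifier is commonly implemented as one or more fully connected (linear) layers, i.e. a multilayer perceptron, at the end of the model. The name of that layer varies between implementations: in torchvision it is fc for Inception V3 and ResNet, and classifier for models such as VGG and DenseNet.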

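# Load a second copy of Inception V3 with the same pre-trained weights
# (weights is a keyword argument in current torchvision versions)
dl_extractor = models.inception_v3(weights=inception_weights, progress=True)
dl_extractor.eval()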

dl_extractor.fc
Linear(in_features=2048, out_features=1000, bias=True)

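# Replace the classifier with an identity layer: the model now returns the
# 2048-dimensional pooled features that used to be the input of fc
dl_extractor.fc = torch.nn.Identity()
sample_fx = dl_extractor(sample_x[None, ...])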

sample_fx.shape
torch.Size([1, 2048])
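Conceptually, the original fc layer maps the 2048 pooled features to 1000 class scores with a matrix product plus a bias, while Identity returns its input unchanged. The NumPy sketch below illustrates this with random placeholder weights, not the trained Inception V3 parameters:

```python
import numpy as np

rng = np.random.default_rng(0)

features = rng.normal(size=(1, 2048))  # placeholder for the pooled features of one image
W = rng.normal(size=(1000, 2048))      # placeholder classifier weights (out x in)
b = rng.normal(size=1000)              # placeholder classifier bias


def linear(x):
    # What torch.nn.Linear(2048, 1000) computes: x @ W^T + b
    return x @ W.T + b


def identity(x):
    # What torch.nn.Identity() computes: the input, unchanged
    return x


print(linear(features).shape)    # (1, 1000): class scores
print(identity(features).shape)  # (1, 2048): features, like sample_fx
```

The resulting 2048-dimensional vectors can then serve as inputs to other models, for example a new classifier trained for a different task.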