Advanced Machine Learning with Python (Session 2 - Part 3)

Fernando Cervantes (fernando.cervantes@jax.org)

Materials

Open notebook in Colab

Setup

!pip install s3fs imagecodecs umap-learn torchmetrics
!git clone https://github.com/jump-cellpainting/JUMP-Target
Cloning into 'JUMP-Target'...
remote: Enumerating objects: 243, done.
remote: Counting objects: 100% (78/78), done.
remote: Compressing objects: 100% (69/69), done.
remote: Total 243 (delta 40), reused 18 (delta 9), pack-reused 165 (from 1)
Receiving objects: 100% (243/243), 289.62 KiB | 8.27 MiB/s, done.
Resolving deltas: 100% (133/133), done.
!git clone https://github.com/jump-cellpainting/datasets.git
Cloning into 'datasets'...
remote: Enumerating objects: 880, done.
remote: Counting objects: 100% (451/451), done.
remote: Compressing objects: 100% (166/166), done.
remote: Total 880 (delta 342), reused 310 (delta 279), pack-reused 429 (from 1)
Receiving objects: 100% (880/880), 91.91 MiB | 79.02 MiB/s, done.
Resolving deltas: 100% (457/457), done.

Review the metadata of the cpg0016-jump dataset

import pandas as pd
jump_plates_metadata = pd.read_csv("datasets/metadata/plate.csv.gz")
jump_plates_metadata["Metadata_PlateType"].unique()
array(['COMPOUND_EMPTY', 'COMPOUND', 'DMSO', 'TARGET2', 'CRISPR', 'ORF',
       'TARGET1', 'POSCON8'], dtype=object)
jump_plates_metadata.groupby(["Metadata_Source", "Metadata_Batch"]).describe()
                                Metadata_Plate                     Metadata_PlateType
                                count unique        top freq      count unique      top freq
Metadata_Source Metadata_Batch
source_1        Batch1_20221004     9      9   UL000109    1          9      2 COMPOUND    6
                Batch2_20221006     7      7   UL001647    1          7      1 COMPOUND    7
                Batch3_20221010     8      8   UL000087    1          8      1 COMPOUND    8
                Batch4_20221012     8      8   UL000081    1          8      1 COMPOUND    8
                Batch5_20221030    11     11   UL000561    1         11      2 COMPOUND   10
...                               ...    ...        ...  ...        ...    ...      ...  ...
source_9        20210918-Run11      9      9 GR00004367    1          9      2 COMPOUND    8
                20210918-Run12      8      8 GR00004377    1          8      1 COMPOUND    8
                20211013-Run14     13     13 GR00003279    1         13      2 COMPOUND   12
                20211102-Run15     11     11 GR00004391    1         11      2 COMPOUND   10
                20211103-Run16     17     17 GR00004405    1         17      2 COMPOUND   16

149 rows × 8 columns
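
For a quick orientation before subsetting, a tally like the following (illustrative; not part of the original notebook) shows how many plates of each type the metadata describes:

# Illustrative: count plates per plate type across all sources
jump_plates_metadata["Metadata_PlateType"].value_counts()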

Subset the dataset to extract samples with CRISPR, ORF, and NONE/DMSO treatments

crispr_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_crispr_platemap.tsv", sep="\t")
crispr_wells_metadata["Plate_type"] = "CRISPR"
crispr_wells_metadata["Plate_label"] = 1
crispr_wells_metadata
well_position broad_sample Plate_type Plate_label
0 A01 BRDN0001480888 CRISPR 1
1 A02 BRDN0001483495 CRISPR 1
2 A03 BRDN0001147364 CRISPR 1
3 A04 BRDN0001490272 CRISPR 1
4 A05 BRDN0001480510 CRISPR 1
... ... ... ... ...
379 P20 BRDN0001145303 CRISPR 1
380 P21 BRDN0001484228 CRISPR 1
381 P22 BRDN0001487618 CRISPR 1
382 P23 BRDN0001487864 CRISPR 1
383 P24 BRDN0000735603 CRISPR 1

384 rows × 4 columns

Subset the dataset to extract samples with CRISPR, ORF, and NONE/DMSO treatments

orf_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_orf_platemap.tsv", sep="\t")
orf_wells_metadata["Plate_type"] = "ORF"
orf_wells_metadata["Plate_label"] = 2
orf_wells_metadata
well_position broad_sample Plate_type Plate_label
0 A01 ccsbBroad304_00900 ORF 2
1 A02 ccsbBroad304_07795 ORF 2
2 A03 ccsbBroad304_02826 ORF 2
3 A04 ccsbBroad304_01492 ORF 2
4 A05 ccsbBroad304_00691 ORF 2
... ... ... ... ...
379 P20 ccsbBroad304_00277 ORF 2
380 P21 ccsbBroad304_06464 ORF 2
381 P22 ccsbBroad304_00476 ORF 2
382 P23 ccsbBroad304_01649 ORF 2
383 P24 ccsbBroad304_03934 ORF 2

384 rows × 4 columns

Subset the dataset to extract samples with CRISPR, ORF, and NONE/DMSO treatments

compound_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_compound_platemap.tsv", sep="\t")
compound_wells_metadata["Plate_type"] = "COMPOUND"
compound_wells_metadata["Plate_label"] = 3
compound_wells_metadata
well_position broad_sample solvent Plate_type Plate_label
0 A01 BRD-A86665761-001-01-1 DMSO COMPOUND 3
1 A02 NaN DMSO COMPOUND 3
2 A03 BRD-A22032524-074-09-9 DMSO COMPOUND 3
3 A04 BRD-A01078468-001-14-8 DMSO COMPOUND 3
4 A05 BRD-K48278478-001-01-2 DMSO COMPOUND 3
... ... ... ... ... ...
379 P20 BRD-K68982262-001-01-4 DMSO COMPOUND 3
380 P21 BRD-K24616672-003-20-1 DMSO COMPOUND 3
381 P22 BRD-A82396632-008-30-8 DMSO COMPOUND 3
382 P23 BRD-K61250553-003-30-6 DMSO COMPOUND 3
383 P24 BRD-K70358946-001-17-3 DMSO COMPOUND 3

384 rows × 5 columns

Subset the dataset to extract samples with CRISPR, ORF, and NONE/DMSO treatments

wells_metadata = pd.concat([compound_wells_metadata, orf_wells_metadata, crispr_wells_metadata])
wells_metadata.loc[wells_metadata["broad_sample"].isna(), "Plate_label"] = 0
wells_metadata
well_position broad_sample solvent Plate_type Plate_label
0 A01 BRD-A86665761-001-01-1 DMSO COMPOUND 3
1 A02 NaN DMSO COMPOUND 0
2 A03 BRD-A22032524-074-09-9 DMSO COMPOUND 3
3 A04 BRD-A01078468-001-14-8 DMSO COMPOUND 3
4 A05 BRD-K48278478-001-01-2 DMSO COMPOUND 3
... ... ... ... ... ...
379 P20 BRDN0001145303 NaN CRISPR 1
380 P21 BRDN0001484228 NaN CRISPR 1
381 P22 BRDN0001487618 NaN CRISPR 1
382 P23 BRDN0001487864 NaN CRISPR 1
383 P24 BRDN0000735603 NaN CRISPR 1

1152 rows × 5 columns
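
The Plate_label column now encodes the treatment class we will predict later: 0 for untreated wells (those without a broad_sample, i.e. solvent/DMSO only), 1 for CRISPR, 2 for ORF, and 3 for COMPOUND. A quick check of the label distribution (illustrative):

# Illustrative: number of wells per treatment label
wells_metadata["Plate_label"].value_counts().sort_index()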

Get the URL of each assay plate from the S3 bucket

import s3fs

fs = s3fs.S3FileSystem(anon=True)

batch_names = {}
plate_paths = {}
source_names = {}
plate_types = {}

for _, src_row in jump_plates_metadata.groupby(["Metadata_Source", "Metadata_Batch"]).describe().iterrows():
    source_name, batch_name = src_row.name

    # Ignore 'source_8' since the naming of the images is not standard
    if source_name in ["source_8"]:
        continue

    plate_type = src_row["Metadata_PlateType"].top

    for plate_path in fs.ls(f"cellpainting-gallery/cpg0016-jump/{source_name}/images/{batch_name}/images/"):
        plate_path = plate_path.split("/")[-1]
        if not plate_path:
            continue

        plate_name = plate_path.split("__")[0]

        source_names[plate_name] = source_name
        batch_names[plate_name] = batch_name
        plate_types[plate_name] = plate_type
        plate_paths[plate_name] = plate_path

Get the URL of each assay plate from the S3 bucket

plate_maps = pd.DataFrame()

plate_maps["Plate_name"] = batch_names.keys()
plate_maps["Source_name"] = plate_maps["Plate_name"].map(source_names)
plate_maps["Batch_name"] = plate_maps["Plate_name"].map(batch_names)
plate_maps["Plate_type"] = plate_maps["Plate_name"].map(plate_types)
plate_maps["Plate_path"] = plate_maps["Plate_name"].map(plate_paths)

plate_maps
Plate_name Source_name Batch_name Plate_type Plate_path
0 UL000109 source_1 Batch1_20221004 COMPOUND UL000109__2022-10-05T06_35_06-Measurement1
1 UL001641 source_1 Batch1_20221004 COMPOUND UL001641__2022-10-04T23_16_28-Measurement1
2 UL001643 source_1 Batch1_20221004 COMPOUND UL001643__2022-10-04T18_52_42-Measurement2
3 UL001645 source_1 Batch1_20221004 COMPOUND UL001645__2022-10-05T00_44_11-Measurement1
4 UL001651 source_1 Batch1_20221004 COMPOUND UL001651__2022-10-04T20_20_52-Measurement1
... ... ... ... ... ...
2333 GR00004417 source_9 20211103-Run16 COMPOUND GR00004417
2334 GR00004418 source_9 20211103-Run16 COMPOUND GR00004418
2335 GR00004419 source_9 20211103-Run16 COMPOUND GR00004419
2336 GR00004420 source_9 20211103-Run16 COMPOUND GR00004420
2337 GR00004421 source_9 20211103-Run16 COMPOUND GR00004421

2338 rows × 5 columns

Subset the data frame to separate perturbation (CRISPR/ORF/NONE) plates from COMPOUND plates

comp_plate_maps = plate_maps.query("Plate_type=='COMPOUND'")
comp_plate_maps
Plate_name Source_name Batch_name Plate_type Plate_path
0 UL000109 source_1 Batch1_20221004 COMPOUND UL000109__2022-10-05T06_35_06-Measurement1
1 UL001641 source_1 Batch1_20221004 COMPOUND UL001641__2022-10-04T23_16_28-Measurement1
2 UL001643 source_1 Batch1_20221004 COMPOUND UL001643__2022-10-04T18_52_42-Measurement2
3 UL001645 source_1 Batch1_20221004 COMPOUND UL001645__2022-10-05T00_44_11-Measurement1
4 UL001651 source_1 Batch1_20221004 COMPOUND UL001651__2022-10-04T20_20_52-Measurement1
... ... ... ... ... ...
2333 GR00004417 source_9 20211103-Run16 COMPOUND GR00004417
2334 GR00004418 source_9 20211103-Run16 COMPOUND GR00004418
2335 GR00004419 source_9 20211103-Run16 COMPOUND GR00004419
2336 GR00004420 source_9 20211103-Run16 COMPOUND GR00004420
2337 GR00004421 source_9 20211103-Run16 COMPOUND GR00004421

1905 rows × 5 columns

pert_plate_maps = plate_maps[plate_maps["Plate_type"].isin(["CRISPR", "ORF", "DMSO"])]
pert_plate_maps
Plate_name Source_name Batch_name Plate_type Plate_path
142 Dest210628-161651 source_10 2021_06_28_U2OS_48_hr_run9 DMSO Dest210628-161651
143 Dest210628-162003 source_10 2021_06_28_U2OS_48_hr_run9 DMSO Dest210628-162003
457 CP-CC9-R1-01 source_13 20220914_Run1 CRISPR CP-CC9-R1-01
458 CP-CC9-R1-02 source_13 20220914_Run1 CRISPR CP-CC9-R1-02
459 CP-CC9-R1-03 source_13 20220914_Run1 CRISPR CP-CC9-R1-03
... ... ... ... ... ...
1591 BR00127145 source_4 2021_08_30_Batch13 ORF BR00127145__2021-09-22T04_01_46-Measurement1
1592 BR00127146 source_4 2021_08_30_Batch13 ORF BR00127146__2021-09-22T12_25_07-Measurement1
1593 BR00127147 source_4 2021_08_30_Batch13 ORF BR00127147__2021-09-18T10_27_12-Measurement1
1594 BR00127148 source_4 2021_08_30_Batch13 ORF BR00127148__2021-09-21T11_44_23-Measurement1
1595 BR00127149 source_4 2021_08_30_Batch13 ORF BR00127149__2021-09-18T02_10_04-Measurement1

433 rows × 5 columns
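
Before splitting, it can help to see how these 433 perturbation plates break down by type (illustrative):

# Illustrative: plates per perturbation type
pert_plate_maps["Plate_type"].value_counts()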

Split the perturbation plates into Training, Validation, and Test sets

We’ll split the plates of each batch across the three sets so that batch-level effects are represented in every set

import random
import math
trn_plates = []
val_plates = []
tst_plates = []

trn_proportion = 0.7
val_proportion = 0.2
tst_proportion = 0.1

for batch_name in pert_plate_maps["Batch_name"].unique():
    plate_names = pert_plate_maps.query(f"Batch_name == '{batch_name}'")["Plate_name"].tolist()
    random.shuffle(plate_names)

    tst_plates_count = int(math.ceil(len(plate_names) * tst_proportion))
    val_plates_count = int(math.ceil(len(plate_names) * val_proportion))

    tst_plates += plate_names[:tst_plates_count]
    val_plates += plate_names[tst_plates_count:tst_plates_count + val_plates_count]
    trn_plates += plate_names[tst_plates_count + val_plates_count:]
print("Training set size:", len(trn_plates))
print("Validation set size:", len(val_plates))
print("Testing set size:", len(tst_plates))
Training set size: 283
Validation set size: 96
Testing set size: 54
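
Note that random.shuffle draws from Python's global random number generator, so the split above changes on every run. Seeding the generator before the split loop (a minimal sketch; the seed value is arbitrary) makes the plate assignment reproducible:

# Run before the split loop to get the same Training/Validation/Test
# assignment on every execution (the seed value 42 is arbitrary)
random.seed(42)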

Create a PyTorch Dataset to load images from S3 storage

Defining a custom PyTorch dataset lets us access the image data in S3 storage even though it is not stored in a standard format across the distinct sources in the database. Moreover, the dataset is fully iterable: images are streamed and analyzed on the fly, so no additional local storage is needed.

# @title Definition of a Dataset class capable of pulling images from S3 buckets
import random
import numpy as np
import string
import s3fs

from itertools import product

from PIL import Image
import tifffile

from torch.utils.data import IterableDataset, get_worker_info
def load_well(plate_metadata, well_row, well_col, field_id, channels, s3):
    # Collect one image per channel for the requested well and field
    curr_well_image = []

    plate_path = ("cellpainting-gallery/cpg0016-jump/"
                  f"{plate_metadata['Source_name']}/images/"
                  f"{plate_metadata['Batch_name']}/images/"
                  f"{plate_metadata['Plate_path']}")

    for channel_id in range(channels):
        # These sources name images as rXXcXXfXXp01-chNsk1fk1fl1.tiff
        if plate_metadata["Source_name"] in ["source_1", "source_3", "source_4", "source_9", "source_11", "source_15"]:
            image_suffix = f"Images/r{well_row + 1:02d}c{well_col + 1:02d}f{field_id + 1:02d}p01-ch{channel_id + 1}sk1fk1fl1.tiff"

        else:
            # The remaining sources encode the channel with an additional
            # per-channel 'A' index in the file name
            if plate_metadata["Source_name"] in ["source_2", "source_5"]:
                a_locs = [1, 2, 3, 4, 5]
            elif plate_metadata["Source_name"] in ["source_6", "source_10"]:
                a_locs = [1, 2, 2, 3, 1, 4]
            elif plate_metadata["Source_name"] in ["source_7", "source_13"]:
                a_locs = [1, 1, 2, 3, 4]

            image_suffix = f"{plate_metadata['Plate_name']}_{string.ascii_uppercase[well_row]}{well_col + 1:02d}_T0001F{field_id + 1:03d}L01A{a_locs[channel_id]:02d}Z01C{channel_id + 1:02d}.tif"

        image_url = "s3://" + plate_path + "/" + image_suffix

        try:
            with s3.open(image_url, 'rb') as f:
                curr_image = tifffile.imread(f)

        except FileNotFoundError:
            print("Failed retrieving:", image_url)
            return None

        # Scale 16-bit intensities to the [0, 1] range
        curr_image = curr_image.astype(np.float32)
        curr_image /= 2 ** 16 - 1

        curr_well_image.append(curr_image)

    curr_well_image = np.array(curr_well_image)

    return curr_well_image
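
As a quick usage example (illustrative only; this opens an anonymous S3 connection and downloads one field over the network), load_well can be called directly on a row of the plate map:

# Illustrative: fetch field 1 of well A01 from the first perturbation plate
s3 = s3fs.S3FileSystem(anon=True)
example_plate = pert_plate_maps.to_dict(orient="records")[0]
well_image = load_well(example_plate, well_row=0, well_col=0, field_id=0, channels=5, s3=s3)
if well_image is not None:
    print(well_image.shape)  # (channels, height, width), intensities scaled to [0, 1]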

Create a PyTorch Dataset to load images from S3 storage

class TiffS3Dataset(IterableDataset):
    """Iterable dataset that streams well images directly from S3.

    With shuffling enabled, it can yield a virtually infinite stream of samples.
    """
    def __init__(self, plate_maps, wells_metadata, plate_names, well_rows=24, well_cols=16, fields=4, channels=5, shuffle=False):
        super().__init__()

        self._plate_maps = plate_maps
        self._wells_metadata = wells_metadata

        self._plate_names = plate_names
        self._well_rows = well_rows
        self._well_cols = well_cols
        self._fields = fields
        self._channels = channels

        self._shuffle = shuffle

        self._worker_sel = slice(0, len(plate_names) * self._well_rows * self._well_cols)
        self._worker_id = 0
        self._num_workers = 1

        self._s3 = None

    def __iter__(self):
        # Open a fresh S3 connection in each worker process
        self._s3 = s3fs.S3FileSystem(anon=True)

        # Select only the plates that correspond to this worker
        plate_names = self._plate_names[self._worker_sel]

        well_row_range = range(self._well_rows)
        well_col_range = range(self._well_cols)
        fields_range = range(self._fields)

        for plate_name, well_row, well_col, field_id in product(plate_names, well_row_range, well_col_range, fields_range):
            if self._shuffle:
                plate_name = random.choice(plate_names)
                well_row = random.randrange(self._well_rows)
                well_col = random.randrange(self._well_cols)
                field_id = random.randrange(self._fields)

            curr_plate_map = self._plate_maps.query(f"Plate_name == '{plate_name}'")

            # Skip plates that are missing from the plate-map metadata
            if not len(curr_plate_map):
                continue

            curr_plate_metadata = curr_plate_map.to_dict(orient='records')[0]

            curr_image = load_well(curr_plate_metadata, well_row, well_col, field_id, self._channels, self._s3)

            if curr_image is None:
                continue

            curr_plate_metadata["Well_position"] = f"{string.ascii_uppercase[well_row]}{well_col + 1:02d}"

            # Crop, and zero-pad if needed, to a fixed 1080x1080 field size so
            # images from different sources can be collated into batches
            curr_image = curr_image[:, :1080, :1080]
            _, h, w = curr_image.shape
            pad_h = 1080 - h
            pad_w = 1080 - w

            if pad_h or pad_w:
                curr_image = np.pad(curr_image, ((0, 0), (0, pad_h), (0, pad_w)))
            
            if curr_plate_metadata["Plate_type"] == "DMSO":
                curr_label = 0

            else:
                # Look up the treatment label for this well in the platemap metadata
                curr_label = self._wells_metadata.query(f"Plate_type=='{curr_plate_metadata['Plate_type']}' & well_position=='{curr_plate_metadata['Well_position']}'")["Plate_label"]

                if not len(curr_label):
                    continue

                curr_label = curr_label.item()

            yield curr_image, curr_label, curr_plate_metadata

        self._s3 = None

Create a PyTorch Dataset to load images from S3 storage

def dataset_worker_init_fn(worker_id):
    """TiffS3Dataset multiprocess worker initialization function.
    """
    worker_info = get_worker_info()
    w_sel = slice(worker_id, None, worker_info.num_workers)

    dataset_obj = worker_info.dataset

    # Reset the random number generator in each worker so that shuffling
    # workers do not all yield the same sequence of samples.
    random.seed(torch.initial_seed() % 2 ** 32)

    dataset_obj._worker_sel = w_sel
    dataset_obj._worker_id = worker_id
    dataset_obj._num_workers = worker_info.num_workers

Create the different datasets from the plate lists

# 384-well plates: 16 rows (A-P) by 24 columns, 9 fields per well, 5 channels
training_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, trn_plates, well_rows=16, well_cols=24, fields=9, channels=5, shuffle=True)
validation_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, val_plates, well_rows=16, well_cols=24, fields=9, channels=5, shuffle=True)
testing_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, tst_plates, well_rows=16, well_cols=24, fields=9, channels=5, shuffle=True)
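
To verify the plumbing before committing to a full pass (illustrative; this streams a single field from S3 and may take a few seconds), the dataset can be iterated directly:

# Illustrative: draw one sample from the streaming dataset
image, label, metadata = next(iter(training_ds))
print(image.shape, label, metadata["Plate_name"])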

Compute field-level morphological profiles from perturbation plates

We’ll use a pre-trained deep learning model for image recognition to extract morphological features at the field level.

This process is usually applied at the cell level; however, we’ll analyze the data at the field level for simplicity.

These morphological profiles will be used subsequently to train a perturbation classifier.

Import a pre-trained model from torchvision

We’ll start with a pre-trained MobileNet model for feature extraction since it is lightweight and fast in terms of the computational resources required to use it.

More complex models are used in the literature, such as Inception V3, DenseNet, or even Vision Transformers; however, those models require GPU acceleration to be applied efficiently.

from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
weights = MobileNet_V3_Small_Weights.DEFAULT
model = mobilenet_v3_small(weights=weights)

Modify the model’s architecture to convert it into a feature extraction function

import torch
org_avgpool = model.avgpool
model.avgpool = torch.nn.Identity()
model.classifier = torch.nn.Identity()

Modify the model’s architecture to convert it into a feature extraction function

if torch.cuda.is_available():
    model.cuda()

model.eval()
MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(72, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=72, bias=False)
          (1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (3): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 88, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(88, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(88, 88, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=88, bias=False)
          (1): BatchNorm2d(88, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(88, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (4): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(96, 96, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=96, bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(96, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (5): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(240, 64, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(64, 240, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (6): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
          (1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(240, 64, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(64, 240, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (7): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
          (1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(120, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (8): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(144, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=144, bias=False)
          (1): BatchNorm2d(144, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(40, 144, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(144, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (9): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(288, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(288, 288, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=288, bias=False)
          (1): BatchNorm2d(288, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(288, 72, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(72, 288, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (10): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(144, 576, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (11): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (2): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(144, 576, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (12): Conv2dNormActivation(
      (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
  )
  (avgpool): Identity()
  (classifier): Identity()
)

Load the pre-processing transforms from the original model

We need to apply the same transforms to the images we feed to the model in order to get the expected behavior.

model_transforms = weights.transforms()
model_transforms
ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
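
These transforms (and the pre-trained weights) assume 3-channel RGB input, while each Cell Painting channel is a single-channel image. In the extraction loops below, every channel is therefore tiled into 3 identical channels before the transforms are applied. A minimal sketch of that conversion, using a random tensor as a stand-in for a real field:

import torch

# Stand-in for a mini-batch of single-channel microscopy images: (batch, 1, H, W)
gray = torch.rand(2, 1, 512, 512)

# Replicate the single channel 3 times so the RGB-trained backbone accepts it
rgb_like = torch.tile(gray, (1, 3, 1, 1))

# Resize, center-crop, and normalize exactly as the original model expects
ready = model_transforms(rgb_like)  # -> torch.Size([2, 3, 224, 224])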

Create a PyTorch DataLoader

A DataLoader takes a Dataset (or IterableDataset) and serves mini-batches of samples that can be used for model training or evaluation. It manages the mini-batch collation and, if enabled, the multi-process loading of data.

from torch.utils.data.dataloader import DataLoader
batch_size = 10

training_dl = DataLoader(training_ds, batch_size=batch_size, num_workers=2, worker_init_fn=dataset_worker_init_fn)
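
As a quick sanity check, we can pull a single mini-batch and inspect its shape (a minimal sketch; training_ds and dataset_worker_init_fn are defined earlier in the notebook):

# Fetch one collated mini-batch; the dataset yields (image, target, metadata)
x, y, _ = next(iter(training_dl))
x.shape, y.shape  # the first dimension of each is batch_size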

Execute the feature extraction with the deep learning model (training set)

from tqdm.auto import tqdm

features = []
targets = []

for i, (x, y, _) in tqdm(enumerate(training_dl)):
    b, c, h, w = x.shape
    # Tile each single-channel image into 3 identical channels before
    # applying the pre-trained model's transforms.
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))

    if torch.cuda.is_available():
        x_t = x_t.cuda()

    with torch.no_grad():
        x_out = model(x_t)
        # Aggregate the per-channel 576x7x7 feature maps by summing over the
        # c channels, then average-pool to one 576-dimensional profile per field.
        x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
        x_out = org_avgpool(x_out).detach().reshape(b, -1)

    features.append(x_out)
    targets.append(y)

    # This is for illustration purposes.
    # We'll load the pre-extracted features from Cloud Storage, so no need to generate it here.
    break

features = torch.cat(features, dim=0)
targets = torch.cat(targets, dim=0)

features.shape, targets.shape
(torch.Size([10, 576]), torch.Size([10]))

Execute the feature extraction with the deep learning model (validation set)

val_features = []
val_targets = []

validation_dl = DataLoader(validation_ds, batch_size=batch_size, num_workers=2, worker_init_fn=dataset_worker_init_fn)

for i, (x, y, _) in tqdm(enumerate(validation_dl)):
    b, c, h, w = x.shape
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))

    if torch.cuda.is_available():
        x_t = x_t.cuda()
    
    with torch.no_grad():
        x_out = model(x_t)
        x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
        x_out = org_avgpool(x_out).detach().reshape(b, -1)

    val_features.append(x_out)
    val_targets.append(y)

    break

val_features = torch.cat(val_features, dim=0)
val_targets = torch.cat(val_targets, dim=0)

Execute the feature extraction with the deep learning model (testing set)

tst_features = []
tst_targets = []

testing_dl = DataLoader(testing_ds, batch_size=batch_size, num_workers=2, worker_init_fn=dataset_worker_init_fn)

for i, (x, y, _) in tqdm(enumerate(testing_dl)):
    b, c, h, w = x.shape
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))

    if torch.cuda.is_available():
        x_t = x_t.cuda()
    
    with torch.no_grad():
        x_out = model(x_t)
        x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
        x_out = org_avgpool(x_out).detach().reshape(b, -1)

    tst_features.append(x_out)
    tst_targets.append(y)

    break

tst_features = torch.cat(tst_features, dim=0)
tst_targets = torch.cat(tst_targets, dim=0)

Train a perturbation classifier model

We’ll take the pre-extracted features and train a perturbation classifier.

This approach will use a Multilayer Perceptron (MLP) model to classify the field-level profiles into three categories: NONE/DMSO = 0, CRISPR = 1, and ORF = 2.

The model will be fitted with the Adam optimizer, whose objective is to minimize the cross-entropy loss between the predicted and the ground-truth category of each field.
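
For reference, given the raw logits \(z\) produced by the classifier for a field with ground-truth class \(y\), the cross-entropy loss is

\[
\mathcal{L}_{\mathrm{CE}}(z, y) = -\log \frac{e^{z_y}}{\sum_{k=1}^{3} e^{z_k}}
\]

PyTorch's CrossEntropyLoss applies the softmax internally, so the model outputs raw logits rather than probabilities.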

Load the pre-extracted features to train the classifier

class GCPStorageDataset(IterableDataset):
    """This dataset loads pre-extracted features from Cloud Storage.
    """
    def __init__(self, features_url, reducer=None, shuffle=False):
        super().__init__()

        self._features_url = features_url
        self._reducer = reducer

        self._shuffle = shuffle

        # Overwritten by the DataLoader's worker_init_fn so that each worker
        # serves a disjoint subset of the feature files.
        self._worker_sel = slice(0, len(self._features_url))
        self._worker_id = 0
        self._num_workers = 1

    def __iter__(self):
        # Select the feature files that correspond to this worker
        worker_urls = self._features_url[self._worker_sel]

        if self._shuffle:
            random.shuffle(worker_urls)

        for url in worker_urls:
            features_dict = torch.load(url)

            if self._reducer is not None:
                embeddings = self._reducer.transform(features_dict["features"])

            curr_n_samples = len(features_dict["features"])

            # Visit every sample exactly once, in random order when shuffling
            indices = list(range(curr_n_samples))
            if self._shuffle:
                random.shuffle(indices)

            for index in indices:
                feats = features_dict["features"][index]
                target = features_dict["targets"][index]

                if self._reducer is not None:
                    reduced_feats = embeddings[index]
                else:
                    reduced_feats = None

                yield feats, target, reduced_feats
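
The _worker_sel, _worker_id, and _num_workers attributes are meant to be overwritten by the DataLoader's worker_init_fn so that each worker serves a disjoint subset of the files. The notebook's dataset_worker_init_fn is defined earlier; a sketch of what such a function can look like:

from torch.utils.data import get_worker_info

def example_worker_init_fn(worker_id):
    # Illustrative only; the notebook uses its own dataset_worker_init_fn.
    worker_info = get_worker_info()
    dataset = worker_info.dataset
    dataset._worker_id = worker_id
    dataset._num_workers = worker_info.num_workers
    # Assign every num_workers-th file to this worker, starting at its own id
    dataset._worker_sel = slice(worker_id, len(dataset._features_url), worker_info.num_workers)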

Inspect the distribution of the feature space

import umap
reducer = umap.UMAP()
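
UMAP follows the scikit-learn estimator API: fit learns the 2-D embedding from the training features, and the same fitted reducer is then reused to transform the validation and test features, so all splits live in the same embedding space. A minimal sketch of the pattern, with a random matrix standing in for real features:

import numpy as np

rng = np.random.default_rng(0)
fake_features = rng.normal(size=(100, 576))

embedding = umap.UMAP().fit_transform(fake_features)
embedding.shape  # (100, 2)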

Inspect the distribution of the feature space

trn_features_ds = GCPStorageDataset(["trn_features_002.pt"], shuffle=False)

trn_features, trn_targets, _ = zip(*trn_features_ds)

trn_features = torch.stack(trn_features)
trn_targets = torch.tensor(trn_targets)

reducer.fit(trn_features)

trn_embeddings = reducer.transform(trn_features)
FileNotFoundError: [Errno 2] No such file or directory: 'trn_features_002.pt'

Inspect the distribution of the feature space

import matplotlib.pyplot as plt

class_names = ["NONE/DMSO", "CRISPR", "ORF", "COMPOUND"]
class_markers = ["*", "s", "o", "^"]
class_colors = ["black", "red", "blue", "green"]
class_facecolors = ["black", "none", "none", "none"]

for class_idx, class_name in enumerate(class_names):
    plt.scatter(trn_embeddings[trn_targets == class_idx, 0][::10], trn_embeddings[trn_targets == class_idx, 1][::10], label=class_names[class_idx], marker=class_markers[class_idx], facecolors=class_facecolors[class_idx], edgecolors=class_colors[class_idx])

plt.legend()
NameError: name 'trn_embeddings' is not defined

Inspect the distribution of the feature space

val_features_ds = GCPStorageDataset(["val_features.pt"], shuffle=False)

val_features, val_targets, _ = zip(*val_features_ds)

val_features = torch.stack(val_features)
val_targets = torch.tensor(val_targets)

val_embedding = reducer.transform(val_features)

val_embedding.shape

for class_idx, class_name in enumerate(class_names):
    plt.scatter(val_embedding[val_targets == class_idx, 0][::10], val_embedding[val_targets == class_idx, 1][::10], label=class_names[class_idx], marker=class_markers[class_idx], facecolors=class_facecolors[class_idx], edgecolors=class_colors[class_idx])

plt.legend()
FileNotFoundError: [Errno 2] No such file or directory: 'val_features.pt'

Create a classifier with a Multilayer Perceptron (MLP) architecture

The MobileNet model extracts \(576\) features per image; these will be the input features for the MLP classifier.

class PerturbationClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(PerturbationClassifier, self).__init__()

        self._reducer = torch.nn.Sequential(
            torch.nn.BatchNorm1d(num_features=num_features),
            torch.nn.Linear(in_features=num_features, out_features=2, bias=False),
        )

        self._classifier = torch.nn.Sequential(
            torch.nn.Dropout(0.1),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=2, out_features=num_classes, bias=False)
        )

    def forward(self, input):
        fx = self._reducer(input)
        
        y_pred = self._classifier(fx)

        return y_pred, fx
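
A quick shape check with random inputs confirms the two outputs, the class logits and the 2-dimensional bottleneck (a minimal sketch; the batch size is arbitrary):

_clf = PerturbationClassifier(576, 3)
_clf.eval()  # BatchNorm1d then uses running statistics, so any batch size works
logits, bottleneck = _clf(torch.randn(4, 576))
logits.shape, bottleneck.shape  # (torch.Size([4, 3]), torch.Size([4, 2]))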

Create a classifier with a Multilayer Perceptron (MLP) architecture

classifier = PerturbationClassifier(576, 3)
if torch.cuda.is_available():
    classifier.cuda()

Create a classifier with a Multilayer Perceptron (MLP) architecture

optimizer = torch.optim.Adam([
    {'params': classifier._reducer.parameters(), 'lr': 1e-5, 'weight_decay': 0.001},
    {'params': classifier._classifier.parameters(), 'lr': 1e-4}
])
classifier_loss_fn = torch.nn.CrossEntropyLoss()
reducer_loss_fn = torch.nn.MSELoss()
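
Two loss terms are optimized jointly: the cross-entropy loss for the 3-way classification, and a mean squared error (MSE) loss that pushes the classifier's 2-dimensional bottleneck to mimic the UMAP embedding of the input features. Each dictionary passed to Adam defines a parameter group with its own hyperparameters; anything not specified in a group falls back to the optimizer's defaults. A quick way to confirm the per-group settings (a minimal sketch):

# Inspect the learning rate and weight decay of each parameter group
for group in optimizer.param_groups:
    print(group["lr"], group["weight_decay"])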

Train the MLP model

trn_feat_ds = GCPStorageDataset([f"trn_features_{i:03d}.pt" for i in range(10)], reducer=reducer, shuffle=True)
val_feat_ds = GCPStorageDataset(["val_features.pt"], reducer=reducer, shuffle=False)
tst_feat_ds = GCPStorageDataset(["tst_features.pt"], reducer=reducer, shuffle=False)
trn_feat_dl = DataLoader(trn_feat_ds, batch_size=100, num_workers=2, worker_init_fn=dataset_worker_init_fn)
val_feat_dl = DataLoader(val_feat_ds, batch_size=100)
tst_feat_dl = DataLoader(tst_feat_ds, batch_size=100)

Train the MLP model

from torchmetrics.classification import Accuracy

n_dmso = 0
n_crispr = 0
n_orf = 0

# Training loop
classifier.train()

cls_loss_epoch = 0
red_loss_epoch = 0

trn_acc_metric = Accuracy(task="multiclass", num_classes=3)

for x, y, fx in tqdm(trn_feat_dl, total=1000):
    optimizer.zero_grad()

    if torch.cuda.is_available():
        x = x.cuda()

    y_pred, fx_pred = classifier(x)

    cls_loss = classifier_loss_fn(y_pred.cpu(), y)
    red_loss = reducer_loss_fn(fx_pred.cpu(), fx)

    cls_loss.backward(retain_graph=True)
    red_loss.backward()

    optimizer.step()

    cls_loss_epoch += cls_loss.item()
    red_loss_epoch += red_loss.item()

    trn_acc_metric(y_pred.cpu().softmax(dim=1), y)

    n_dmso += sum(y == 0)
    n_crispr += sum(y == 1)
    n_orf += sum(y == 2)
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0: [Errno 2] No such file or directory: 'trn_features_006.pt'

Train the MLP model

n_total = n_dmso + n_crispr + n_orf
n_dmso / n_total, n_crispr / n_total, n_orf / n_total
ZeroDivisionError: division by zero
cls_loss_epoch / n_total
ZeroDivisionError: division by zero
red_loss_epoch / n_total
ZeroDivisionError: division by zero
trn_acc_metric.compute()
tensor(0.)

Train the MLP model

n_dmso = 0
n_crispr = 0
n_orf = 0

cls_loss_epoch = 0
red_loss_epoch = 0

val_acc_metric = Accuracy(task="multiclass", num_classes=3)

# Switch to evaluation mode so Dropout is disabled and BatchNorm uses its
# running statistics during validation.
classifier.eval()

for x_val, y_val, fx_val in tqdm(val_feat_dl):
    with torch.no_grad():
        if torch.cuda.is_available():
            x_val = x_val.cuda()

        y_val_pred, fx_val_pred = classifier(x_val)

        cls_loss = classifier_loss_fn(y_val_pred.cpu(), y_val)
        red_loss = reducer_loss_fn(fx_val_pred.cpu(), fx_val)

    cls_loss_epoch += cls_loss.item()
    red_loss_epoch += red_loss.item()
    
    val_acc_metric(y_val_pred.cpu().softmax(dim=1), y_val)

    n_dmso += sum(y_val == 0)
    n_crispr += sum(y_val == 1)
    n_orf += sum(y_val == 2)
FileNotFoundError: [Errno 2] No such file or directory: 'val_features.pt'

Train the MLP model

n_total = n_dmso + n_crispr + n_orf
n_dmso / n_total, n_crispr / n_total, n_orf / n_total
ZeroDivisionError: division by zero
cls_loss_epoch / n_total
ZeroDivisionError: division by zero
red_loss_epoch / n_total
ZeroDivisionError: division by zero
val_acc_metric.compute()
tensor(0.)

Wrap the training and validation steps for multiple epochs

Track the performance of the model across epochs during training

avg_cls_loss_trn = []
avg_red_loss_trn = []
avg_acc_trn = []

avg_cls_loss_val = []
avg_red_loss_val = []
avg_acc_val = []

n_epochs = 20
q = tqdm(total=n_epochs)

for e in range(n_epochs):
    # Training loop
    classifier.train()

    cls_loss_epoch = 0
    red_loss_epoch = 0

    trn_acc_metric.reset()
    
    total_samples = 0
    for x, y, fx in trn_feat_dl:
        optimizer.zero_grad()

        if torch.cuda.is_available():
            x = x.cuda()

        y_pred, fx_pred = classifier(x)

        cls_loss = classifier_loss_fn(y_pred.cpu(), y)
        red_loss = reducer_loss_fn(fx_pred.cpu(), fx)

        cls_loss.backward(retain_graph=True)
        red_loss.backward()

        optimizer.step()

        cls_loss_epoch += cls_loss.item() * len(y)
        red_loss_epoch += red_loss.item() * len(y)

        trn_acc_metric(y_pred.cpu().softmax(dim=1), y)
        total_samples += len(y)

    avg_cls_loss_trn.append(cls_loss_epoch / total_samples)
    avg_red_loss_trn.append(red_loss_epoch / total_samples)

    avg_acc_trn.append(trn_acc_metric.compute())

    # Validation loop
    classifier.eval()

    cls_loss_epoch = 0
    red_loss_epoch = 0

    val_acc_metric.reset()

    total_samples = 0
    for x_val, y_val, fx_val in val_feat_dl:
        with torch.no_grad():
            if torch.cuda.is_available():
                x_val = x_val.cuda()

            y_val_pred, fx_val_pred = classifier(x_val)

            cls_loss = classifier_loss_fn(y_val_pred.cpu(), y_val)
            red_loss = reducer_loss_fn(fx_val_pred.cpu(), fx_val)

        cls_loss_epoch += cls_loss.item() * len(y_val)
        red_loss_epoch += red_loss.item() * len(y_val)

        val_acc_metric(y_val_pred.cpu().softmax(dim=1), y_val)
        total_samples += len(y_val)

    avg_cls_loss_val.append(cls_loss_epoch / total_samples)
    avg_red_loss_val.append(red_loss_epoch / total_samples)

    avg_acc_val.append(val_acc_metric.compute())

    q.set_description(f"Average training CE loss: {avg_cls_loss_trn[-1]:0.4f} / MSE loss: {avg_red_loss_trn[-1]:0.4f} (Accuracy: {100 * avg_acc_trn[-1]:0.2f} %). Average validation CE loss: {avg_cls_loss_val[-1]:0.4f} / MSE loss: {avg_red_loss_val[-1]:0.4f} (Accuracy: {100 * avg_acc_val[-1]:0.2f} %)")
    q.update()
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0: [Errno 2] No such file or directory: 'trn_features_002.pt'

Review the performance of the model throughout training

plt.plot(avg_cls_loss_trn, "k-", label="Training CE loss")
plt.plot(avg_cls_loss_val, "b:", label="Validation CE loss")
plt.legend()
plt.show()

plt.plot(avg_red_loss_trn, "k-", label="Training MSE loss")
plt.plot(avg_red_loss_val, "b:", label="Validation MSE loss")
plt.legend()
plt.show()

Review the performance of the model throughout training

plt.plot(avg_acc_trn, "k-", label="Training accuracy")
plt.plot(avg_acc_val, "b:", label="Validation accuracy")
plt.legend()

Evaluate the model with the withheld testing data

from torchmetrics.classification import ConfusionMatrix

classifier.eval()

n_dmso = 0
n_crispr = 0
n_orf = 0

cls_loss_epoch = 0
red_loss_epoch = 0

tst_acc_metric = Accuracy("multiclass", num_classes=3)
tst_confmat = ConfusionMatrix("multiclass", num_classes=3)

for x_tst, y_tst, fx_tst in tst_feat_dl:
    with torch.no_grad():
        if torch.cuda.is_available():
            x_tst = x_tst.cuda()

        y_tst_pred, fx_tst_pred = classifier(x_tst)
        cls_loss = classifier_loss_fn(y_tst_pred.cpu(), y_tst)
        red_loss = reducer_loss_fn(fx_tst_pred.cpu(), fx_tst)

    cls_loss_epoch += cls_loss.item() * len(y_tst)
    red_loss_epoch += red_loss.item() * len(y_tst)
    
    y_tst_prob = y_tst_pred.cpu().softmax(dim=1)
    tst_acc_metric.update(y_tst_prob, y_tst)
    tst_confmat.update(y_tst_prob, y_tst)

    n_dmso += sum(y_tst == 0)
    n_crispr += sum(y_tst == 1)
    n_orf += sum(y_tst == 2)
FileNotFoundError: [Errno 2] No such file or directory: 'tst_features.pt'

Evaluate the model with the withheld testing data

n_total = n_dmso + n_crispr + n_orf
n_dmso / n_total, n_crispr / n_total, n_orf / n_total
ZeroDivisionError: division by zero
cls_loss_epoch / n_total
ZeroDivisionError: division by zero
red_loss_epoch / n_total
ZeroDivisionError: division by zero

Evaluate the model with the withheld testing data

tst_acc_metric.compute()
tst_confmat.compute()
tst_confmat.plot()
(<Figure size 640x480 with 1 Axes>,
 <Axes: xlabel='Predicted class', ylabel='True class'>)

Evaluate the capacity to mimic the UMAP dimensionality reduction

trn_fx = []
trn_fx_pred = []
trn_y = []

for i, (x, y, fx) in enumerate(trn_feat_dl):
    if i >= 10:
        break

    with torch.no_grad():
        if torch.cuda.is_available():
            x = x.cuda()

        _, fx_pred = classifier(x)
        trn_fx_pred.append(fx_pred.detach().cpu())
        trn_fx.append(fx)
        trn_y.append(y)

trn_fx = torch.cat(trn_fx, dim=0)
trn_fx_pred = torch.cat(trn_fx_pred, dim=0)
trn_y = torch.cat(trn_y, dim=0)
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0: [Errno 2] No such file or directory: 'trn_features_006.pt'

Evaluate the capacity to mimic the UMAP dimensionality reduction

for class_idx, class_name in enumerate(class_names):
    plt.scatter(trn_fx[trn_y == class_idx, 0], trn_fx[trn_y == class_idx, 1], label=class_names[class_idx], marker=class_markers[class_idx], facecolors=class_facecolors[class_idx], edgecolors=class_colors[class_idx])

plt.legend()
TypeError: list indices must be integers or slices, not tuple
for class_idx, class_name in enumerate(class_names):
    plt.scatter(trn_fx_pred[trn_y == class_idx, 0], trn_fx_pred[trn_y == class_idx, 1], label=class_names[class_idx], marker=class_markers[class_idx], facecolors=class_facecolors[class_idx], edgecolors=class_colors[class_idx])

plt.legend()
TypeError: list indices must be integers or slices, not tuple

Evaluate the capacity to mimic the UMAP dimensionality reduction

val_fx = []
val_fx_pred = []
val_y = []

for i, (x_val, y_val, fx_val) in enumerate(val_feat_dl):
    if i >= 10:
        break

    with torch.no_grad():
        if torch.cuda.is_available():
            x_val = x_val.cuda()

        _, fx_pred_val = classifier(x_val)
        val_fx_pred.append(fx_pred_val.detach().cpu())
        val_fx.append(fx_val)
        val_y.append(y_val)

val_fx = torch.cat(val_fx, dim=0)
val_fx_pred = torch.cat(val_fx_pred, dim=0)
val_y = torch.cat(val_y, dim=0)
FileNotFoundError: [Errno 2] No such file or directory: 'val_features.pt'

Evaluate the capacity to mimic the UMAP dimensionality reduction

for class_idx, class_name in enumerate(class_names):
    plt.scatter(val_fx[val_y == class_idx, 0], val_fx[val_y == class_idx, 1], label=class_names[class_idx], marker=class_markers[class_idx], facecolors=class_facecolors[class_idx], edgecolors=class_colors[class_idx])

plt.legend()
TypeError: list indices must be integers or slices, not tuple
for class_idx, class_name in enumerate(class_names):
    plt.scatter(val_fx_pred[val_y == class_idx, 0], val_fx_pred[val_y == class_idx, 1], label=class_names[class_idx], marker=class_markers[class_idx], facecolors=class_facecolors[class_idx], edgecolors=class_colors[class_idx])

plt.legend()
TypeError: list indices must be integers or slices, not tuple

Use the pre-trained model to identify the behavior of compounds

Because the model has learned to recognize CRISPR, ORF, and NONE/DMSO effects, it can be used to determine the behavior of any treatment based on its morphological profile.

To do so, first extract the morphological features using the pre-trained MobileNet, and then use the extracted features as input for the classifier model.

Execute the feature extraction and classification pipeline for compound data

compounds_ds = TiffS3Dataset(comp_plate_maps, wells_metadata, comp_plate_maps["Plate_name"].tolist(), 16, 24, 9, 5, shuffle=True)

batch_size = 5

compounds_dl = DataLoader(compounds_ds, batch_size=batch_size, num_workers=2, worker_init_fn=dataset_worker_init_fn)

metadata_list = []
for i, (x, y, metadata) in tqdm(enumerate(compounds_dl)):
    metadata_list.append(metadata)

    b, c, h, w = x.shape
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))

    if torch.cuda.is_available():
        x_t = x_t.cuda()

    with torch.no_grad():
        x_out = model(x_t)
        x_out = x_out.detach().reshape(-1, c, 576, 7, 7).sum(dim=1)
        x_out = org_avgpool(x_out).detach().reshape(b, -1)

        y_pred, fx_pred = classifier(x_out)

    break
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_2/images/20210823_Batch_10/images/1086292105/1086292105_L14_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_2/images/20210816_Batch_9/images/1086292396/1086292396_G01_T0001F008L01A01Z01C01.tif

Execute the feature extraction and classification pipeline for compound data

for class_idx, class_name in enumerate(class_names[:3]):
    plt.scatter(trn_fx_pred[trn_y == class_idx, 0][::10], trn_fx_pred[trn_y == class_idx, 1][::10], label=class_names[class_idx], marker=class_markers[class_idx], facecolors=class_facecolors[class_idx], edgecolors=class_colors[class_idx])

plt.scatter(fx_pred.cpu()[:, 0], fx_pred.cpu()[:, 1], label=class_names[3], marker=class_markers[3], facecolors=class_facecolors[3], edgecolors=class_colors[3])

plt.legend()
TypeError: list indices must be integers or slices, not tuple

Execute the feature extraction and classification pipeline for compound data

metadata
{'Plate_name': ['GR00004377',
  'J12455d',
  'P01_ACPJUM062',
  '110000296341',
  '1086293409'],
 'Source_name': ['source_9', 'source_3', 'source_5', 'source_6', 'source_2'],
 'Batch_name': ['20210918-Run12',
  'CP_26_all_Phenix1',
  'JUMPCPE-20210716-Run12_20210719_162047',
  'p210920CPU2OS48hw384exp028JUMP',
  '20210726_Batch_7'],
 'Plate_type': ['COMPOUND', 'COMPOUND', 'COMPOUND', 'COMPOUND', 'COMPOUND'],
 'Plate_path': ['GR00004377',
  'J12455d__2021-09-24T16_30_28-Measurement1',
  'P01_ACPJUM062',
  '110000296341',
  '1086293409'],
 'Well_position': ['H02', 'P20', 'K14', 'G20', 'C01']}
y_pred.argmax(dim=1)
tensor([1, 1, 1, 1, 1])

Class 1 corresponds to CRISPR in the mapping defined above, so all five compound fields in this batch are predicted to behave like CRISPR perturbations.

wells_metadata.query("well_position == 'H13' & Plate_type=='COMPOUND'")
    well_position broad_sample solvent Plate_type  Plate_label
180           H13          NaN    DMSO   COMPOUND            0