Add linalg.lu_solve #72935
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA when calling the batched MAGMA backend with `trans=True`. We work around it by solving two triangular systems instead. We also update the heuristics for this function, as they were fairly outdated. We found that cuSolver is king, so luckily we do not need to rely on the buggy MAGMA backend for this function. We added extensive tests for this function, as well as tests for the different backends. We also activated the tests for AMD, as those should work as well.

Fixes #61657
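To make the "two triangular systems" workaround concrete, here is a minimal pure-Python sketch of the underlying linear algebra. It is not the code from this PR: the helper names (`forward_sub`, `backward_sub`, `solve_transposed`) and the 3x3 factors are hypothetical and chosen only for illustration. It assumes real-valued factors with `A = P @ L @ U`, `L` unit lower triangular and `U` upper triangular; for complex inputs the transposes would become conjugate transposes.

```python
# Sketch: solve A^T x = b from the factors of A = P L U with two triangular
# solves, without forming A^T and without refactorizing.

def forward_sub(T, b, unit_diag=False):
    """Solve T y = b for a lower-triangular T."""
    y = [0.0] * len(b)
    for i in range(len(b)):
        s = b[i] - sum(T[i][j] * y[j] for j in range(i))
        y[i] = s if unit_diag else s / T[i][i]
    return y

def backward_sub(T, b, unit_diag=False):
    """Solve T y = b for an upper-triangular T."""
    n = len(b)
    y = [0.0] * n
    for i in reversed(range(n)):
        s = b[i] - sum(T[i][j] * y[j] for j in range(i + 1, n))
        y[i] = s if unit_diag else s / T[i][i]
    return y

def transpose(M):
    return [list(col) for col in zip(*M)]

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def matmul(A, B):
    Bt = transpose(B)
    return [[sum(a * b for a, b in zip(row, col)) for col in Bt] for row in A]

def solve_transposed(P, L, U, b):
    # A^T = U^T L^T P^T, hence x = P (L^T)^{-1} (U^T)^{-1} b.
    y = forward_sub(transpose(U), b)                   # U^T is lower triangular
    z = backward_sub(transpose(L), y, unit_diag=True)  # L^T is unit upper triangular
    return matvec(P, z)                                # undo the row permutation

# Hypothetical factors, chosen only for illustration.
P = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]
L = [[1, 0, 0], [0.5, 1, 0], [0.25, 0.5, 1]]
U = [[4, 2, 1], [0, 3, 2], [0, 0, 2]]
b = [1.0, 2.0, 3.0]

A = matmul(P, matmul(L, U))
x = solve_transposed(P, L, U, b)
residual = [r - bi for r, bi in zip(matvec(transpose(A), x), b)]
```

The `unpack+solve_triangular` column in the benchmarks below measures the same idea expressed with `torch.lu_unpack` and `torch.linalg.solve_triangular`.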
### Benchmarking

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[------------------------------ linalg.lu_solve CUDA ------------------------------]
 | lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: --------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 86 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 94 | 27
shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 100 | 27
shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 94 | 27
shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 95 | 27
shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 94 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 99 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 96 | 27
shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 94 | 27
shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 95 | 27
shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 88 | 27
shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 96 | 27
shape torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 96 | 28
shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 96 | 27
shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 96 | 27
shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 97 | 28
shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 96 | 27
shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 96 | 28
shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 97 | 28
shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 97 | 28
shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 88 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 97 | 28
shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 100 | 28
shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 97 | 28
shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 96 | 28
shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 98 | 28
shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 96 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 96 | 28
shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 97 | 28
shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 100 | 28
shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 88 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 97 | 28
shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 100 | 28
shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 97 | 28
shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 97 | 28
shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 97 | 28
shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 100 | 28
shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 98 | 28
shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 100 | 28
shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 117 | 28
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 88 | 28
shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 97 | 28
shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 97 | 30
shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 100 | 31
shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 96 | 31
shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 96 | 35
shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 97 | 36
shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 96 | 43
shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 191 | 82
shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 312 | 122
shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 68 | 34
shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 85 | 52
shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 150 | 65
shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 278 | 65
shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 95 | 67
shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 97 | 73
shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 176 | 96
shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 551 | 208
shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 919 | 306
shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 90 | 42
shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 164 | 83
shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 301 | 136
shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 578 | 138
shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 199 | 143
shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 241 | 152
shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 322 | 177
shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 514 | 228
shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 1697 | 502
shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 3040 | 770
shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 178 | 72
shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 329 | 139
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 655 | 278
shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 1321 | 286
shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 472 | 330
shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 634 | 360
shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 925 | 408
shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 1582 | 543
shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 5711 | 1310
shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 10660 | 2122
```

</details>

<details>
<summary>
Benchmark Results (adjoint=True)
</summary>

```
[-------------------------- linalg.lu_solve CUDA Adjoint --------------------------]
 | lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic
1 threads: --------------------------------------------------------------------------
shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 98 | 27
shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 110 | 27
shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 110 | 27
shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 110 | 27
shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 110 | 27
shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 110 | 27
shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 100 | 27
shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 110 | 27
shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 110 | 28
shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 100 | 27
shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 99 | 28
shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 100 | 27
shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 107 | 27
shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 110 | 27
shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 110 | 28
shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 111 | 27
shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 107 | 28
shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 108 | 28
shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 110 | 28
shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 99 | 28
shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 107 | 28
shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 110 | 28
shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 107 | 28
shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 107 | 28
shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 110 | 28
shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 108 | 28
shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 108 | 28
shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 108 | 28
shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 108 | 28
shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 110 | 28
shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 108 | 28
shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 110 | 28
shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 110 | 28
shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 110 | 28
shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 100 | 28
shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 108 | 28
shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 108 | 28
shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 107 | 37
shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 99 | 28
shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 110 | 29
shape torch.Size([4, 32, 32]) | 3021 | 86 | 31 | 84 | 110 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 108 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 107 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 110 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 110 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 110 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 198 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 318 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 88 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 109 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 132 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 180 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 107 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 100 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 183 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 549 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 918 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 89 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 110 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 131 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 208 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 198 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 246 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 323 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 512 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 1700 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 3018 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 88 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 117 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 209 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 394 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 480 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 629 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 921 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 1579 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 5616 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277

Times are in microseconds (us).
```

</details>

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I ran the following script, changing the variable `name`. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out one).

<details>
<summary>
Benchmarking script
</summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

# Name of the backend under test; it also names the output pickle file
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


def f(LU, pivots, B, adjoint):
    # Reference implementation: unpack the factors and do two triangular solves
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

# Save the measurements so they can be merged with those of the other backends
with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```

</details>

Finally, I joined all the results with the following script:

<details>
<summary>
Script to join the results
</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

# One pickle file per backend, as written by the benchmarking script
files = [
    "looped_magma",
    "looped cusolver",
    "batched cublas",
    "batched magma",
    "unpack+solve_triangular",
    "heuristic",
]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```

</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`

I also developed the following workaround for MAGMA's bug, but I ended up not using it and preferred the triangular solves, as they were faster. I'm leaving it here in case it's useful in the future.
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  if (trans == TransposeType::NoTranspose) {
    lu_solve_batched_magma(LU, pivots, B, trans);
    return;
  }
  // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
  // The LU of the transpose is not the transpose of the LU
  // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
  auto diag = LU.diagonal(0, -2, -1);
  auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) + LU.triu(1).div_(diag.unsqueeze(-1));
  LU_f.diagonal(0, -2, -1).copy_(diag);
  if (trans == TransposeType::ConjTranspose) {
    LU_f = LU_f.conj_physical();
  }
  LU_f.transpose_(-2, -1);
  // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

  // Trivial permutation
  auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
  lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

  // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
  // Fill `perm` with the identity permutation (perhaps batched)
  // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
  const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
    .add_output(perm)
    .add_input(pivots)
    .build();
  unpack_pivots_stub(pivots.device().type(), iter, m);
  B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>
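The comments in the snippet above compress the reason why transposing the packed LU factors is not enough on its own. A short derivation may help. It assumes the usual LAPACK convention in which $L$ is unit lower triangular and $U$ is upper triangular with nonzero diagonal; $D$, $L'$ and $U'$ follow the code comment, while $Y$ is a symbol introduced here only for illustration.

$$
A = P L U, \qquad D = \operatorname{diag}(U), \qquad L' = L D, \qquad U' = D^{-1} U \quad\Longrightarrow\quad L U = L' U'.
$$

Transposing the rescaled product gives

$$
(L U)^T = U'^T L'^T,
$$

where $U'^T$ is unit lower triangular and $L'^T$ is upper triangular with diagonal $D$. This is exactly the shape a non-transposed LU solve expects, so the rescaled and transposed factors can be passed with the identity pivots to obtain $Y$ from

$$
U'^T L'^T \, Y = B .
$$

The permutation is applied afterwards:

$$
A^T X = U^T L^T P^T X = B \quad\Longrightarrow\quad P^T X = Y \quad\Longrightarrow\quad X = P\,Y .
$$

For `TransposeType::ConjTranspose` the same steps should hold with every transpose replaced by a conjugate transpose, which is what the `conj_physical()` call accounts for.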
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA when calling the batched MAGMA backend with trans=True. We work around that by solving the system solving two triangular systems. We also update the heuristics for this function, as they were fairly outdated. We found that cuSolver is king, so luckily we do not need to rely on the buggy backend from magma for this function. We added tests testing this function left and right. We also added tests for the different backends. We also activated the tests for AMD, as those should work as well. ### Benchmarking <details> <summary> Benchmark Results (adjoint=False) </summary> ``` --------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------] | lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic 1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 86 | 27 shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 94 | 27 shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 100 | 27 shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 94 | 27 shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 95 | 27 shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 94 | 27 shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 99 | 27 shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 96 | 27 shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 94 | 27 shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 95 | 27 shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 88 | 27 shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 96 | 27 shape 
torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 96 | 28 shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 96 | 27 shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 96 | 27 shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 97 | 28 shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 96 | 27 shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 96 | 28 shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 97 | 28 shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 97 | 28 shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 88 | 28 shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 97 | 28 shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 100 | 28 shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 97 | 28 shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 96 | 28 shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 98 | 28 shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 96 | 28 shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 96 | 28 shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 97 | 28 shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 100 | 28 shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 88 | 28 shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 97 | 28 shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 100 | 28 shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 97 | 28 shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 97 | 28 shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 97 | 28 shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 100 | 28 shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 98 | 28 shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 100 | 28 shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 117 | 28 shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 88 | 28 shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 97 | 28 shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 97 
| 30 shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 100 | 31 shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 96 | 31 shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 96 | 35 shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 97 | 36 shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 96 | 43 shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 191 | 82 shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 312 | 122 shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 68 | 34 shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 85 | 52 shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 150 | 65 shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 278 | 65 shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 95 | 67 shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 97 | 73 shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 123 | 76 shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 176 | 96 shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 551 | 208 shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 919 | 306 shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 90 | 42 shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 164 | 83 shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 301 | 136 shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 578 | 138 shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 199 | 143 shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 241 | 152 shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 322 | 177 shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 514 | 228 shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 1697 | 502 shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 3040 | 770 shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 178 | 72 shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 329 | 139 
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 655 | 278 shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 1321 | 286 shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 472 | 330 shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 634 | 360 shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 925 | 408 shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 1582 | 543 shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 5711 | 1310 shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 10660 | 2122 ``` </details> <details> <summary> Benchmark Results (adjoint=True) </summary> ``` [----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------] | lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic 1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 98 | 27 shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 110 | 27 shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 110 | 27 shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 110 | 27 shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 110 | 27 shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 110 | 27 shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 100 | 27 shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 110 | 27 shape torch.Size([512, 1, 1]) | DA90 383300 | 8100 | 28 | 78 | 110 | 28 shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 100 | 27 shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 99 | 28 shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 
| 50 | 110 | 28 shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 100 | 27 shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 107 | 27 shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 110 | 27 shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 110 | 28 shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 111 | 27 shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 107 | 28 shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 108 | 28 shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 110 | 28 shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 99 | 28 shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 107 | 28 shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 110 | 28 shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 107 | 28 shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 107 | 28 shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 110 | 28 shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 108 | 28 shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 108 | 28 shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 108 | 28 shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 108 | 28 shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 99 | 28 shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 110 | 28 [85/469] shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 108 | 28 shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 110 | 28 shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 110 | 28 shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 110 | 28 shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 100 | 28 shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 108 | 28 shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 108 | 28 shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 107 | 37 shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 99 | 28 shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 110 | 29 shape torch.Size([4, 32, 32]) | 
3021 | 86 | 31 | 84 | 110 | 32
shape torch.Size([8, 32, 32])      |   6040 |   167 |   32 |  167 |   108 |   32
shape torch.Size([16, 32, 32])     |  12100 |   330 |   33 |   78 |   107 |   33
shape torch.Size([32, 32, 32])     |  24150 |   662 |   35 |   78 |   110 |   35
shape torch.Size([64, 32, 32])     |  48200 |  1323 |   36 |   79 |   110 |   36
shape torch.Size([128, 32, 32])    |  97000 |  2637 |   44 |   79 |   110 |   43
shape torch.Size([512, 32, 32])    | 382500 | 10580 |   83 |  180 |   198 |   83
shape torch.Size([1024, 32, 32])   | 766600 | 22670 |  123 |  260 |   318 |  120
shape torch.Size([1, 64, 64])      |    760 |    33 |   58 |   33 |    88 |   34
shape torch.Size([2, 64, 64])      |   1520 |    60 |   60 |   60 |   109 |   59
shape torch.Size([4, 64, 64])      |   3016 |   115 |   67 |  119 |   132 |   66
shape torch.Size([8, 64, 64])      |   6120 |   230 |   67 |  233 |   180 |   68
shape torch.Size([16, 64, 64])     |  12100 |   457 |   69 |   86 |   107 |   69
shape torch.Size([32, 64, 64])     |  24000 |   912 |   74 |   95 |   100 |   74
shape torch.Size([64, 64, 64])     |  48000 |  1833 |   76 |  106 |   123 |   76
shape torch.Size([128, 64, 64])    |  95000 |  3636 |   97 |  163 |   183 |   97
shape torch.Size([512, 64, 64])    | 380600 | 15600 |  210 |  464 |   549 |  210
shape torch.Size([1024, 64, 64])   | 761200 | 31140 |  308 |  741 |   918 |  308
shape torch.Size([1, 128, 128])    |    756 |    46 |  120 |   47 |    89 |   46
shape torch.Size([2, 128, 128])    |   1500 |    91 |  123 |   89 |   110 |   92
shape torch.Size([4, 128, 128])    |   2994 |   178 |  139 |  180 |   131 |  139
shape torch.Size([8, 128, 128])    |   5960 |   350 |  140 |  354 |   208 |  142
shape torch.Size([16, 128, 128])   |  12000 |   701 |  144 |  177 |   198 |  143
shape torch.Size([32, 128, 128])   |  23870 |  1401 |  155 |  225 |   246 |  155
shape torch.Size([64, 128, 128])   |  47600 |  2948 |  179 |  288 |   323 |  180
shape torch.Size([128, 128, 128])  |  96000 |  5910 |  231 |  442 |   512 |  231
shape torch.Size([512, 128, 128])  | 381200 | 23640 |  519 | 1400 |  1700 |  519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 |  794 | 2436 |  3018 |  794
shape torch.Size([1, 256, 256])    |    760 |    74 |  246 |   77 |    88 |   78
shape torch.Size([2, 256, 256])    |   1510 |   150 |  256 |  150 |   117 |  150
shape torch.Size([4, 256, 256])    |   3030 |   296 |  284 |  296 |   209 |  284
shape torch.Size([8, 256, 256])    |   6100 |   588 |  286 |  592 |   394 |  288
shape torch.Size([16, 256, 256])   |  12200 |  1238 |  330 |  445 |   480 |  330
shape torch.Size([32, 256, 256])   |  24430 |  2476 |  368 |  568 |   629 |  367
shape torch.Size([64, 256, 256])   |  49000 |  4950 |  415 |  800 |   921 |  414
shape torch.Size([128, 256, 256])  |  96000 |  9900 |  552 | 1330 |  1579 |  553
shape torch.Size([512, 256, 256])  | 369400 | 39580 | 1410 | 4614 |  5616 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277

Times are in microseconds (us).
```

</details>

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. I then ran the following script, changing the variable `name` for each backend. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented-out one).

<details>
<summary> Benchmarking script </summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

# Name of the backend being benchmarked; change it for each run
name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")


def f(LU, pivots, B, adjoint):
    # Reference path: unpack the factorization and do two triangular solves
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)


for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

# Save this backend's measurements so they can be joined later
with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```

</details>

Finally, I joined all the results with the following script:

<details>
<summary> Script to join the results </summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
    "looped_magma",
    "looped cusolver",
    "batched cublas",
    "batched magma",
    "unpack+solve_triangular",
    "heuristic",
]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```

</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`

I also developed the following workaround for MAGMA's bug, but I ended up not using it and preferred the triangular solves instead, as they were faster. I'm leaving it here in case it's useful in the future.

<details>
<summary> Fix for MAGMA's issue with `adjoint=True` </summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  if (trans == TransposeType::NoTranspose) {
    lu_solve_batched_magma(LU, pivots, B, trans);
    return;
  }
  // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
  // The LU of the transpose is not the transpose of the LU
  // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
  auto diag = LU.diagonal(0, -2, -1);
  auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) + LU.triu(1).div_(diag.unsqueeze(-1));
  LU_f.diagonal(0, -2, -1).copy_(diag);
  if (trans == TransposeType::ConjTranspose) {
    LU_f = LU_f.conj_physical();
  }
  LU_f.transpose_(-2, -1);
  // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

  // Trivial permutation
  auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
  lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

  // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
  // Fill `perm` with the identity permutation (perhaps batched)
  // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
  const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
    .add_output(perm)
    .add_input(pivots)
    .build();
  unpack_pivots_stub(pivots.device().type(), iter, m);
  B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>

Fixes #61657 [ghstack-poisoned]
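For reference, here is a sketch of the algebra behind the rescaling in the workaround above, as I read it from the code comments. It assumes $A = PLU$ with $L$ unit lower triangular and $D = \operatorname{diag}(U)$ invertible, and shows the plain transpose; the conjugate-transpose case would additionally conjugate every factor.

```latex
\begin{align*}
A^{T} &= U^{T} L^{T} P^{T}
       = \bigl(U^{T} D^{-1}\bigr)\bigl(D L^{T}\bigr) P^{T}
       = \bigl(D^{-1} U\bigr)^{T} \bigl(L D\bigr)^{T} P^{T}, \\
A^{T} X = B
  &\iff \bigl(D^{-1} U\bigr)^{T} \bigl(L D\bigr)^{T} Y = B,
  \qquad X = P\,Y .
\end{align*}
```

Here $(D^{-1}U)^{T}$ is unit lower triangular and $(LD)^{T}$ is upper triangular with diagonal $D$, so the two together form a valid packed LU factorization that can be passed to a non-transposed solve with trivial pivots. The permutation $P$ is then applied to the result afterwards.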
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA when calling the batched MAGMA backend with trans=True. We work around that by solving two triangular systems instead. We also update the heuristics for this function, as they were fairly outdated. We found that cuSolver is king, so luckily we do not need to rely on the buggy backend from MAGMA for this function. We added tests testing this function left and right. We also added tests for the different backends. We also activated the tests for AMD, as those should work as well. Fixes #61657 ghstack-source-id: 7076c75 Pull Request resolved: #72935
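To illustrate the "two triangular systems" idea, here is a minimal pure-Python sketch. It is not the PyTorch implementation: `lu_factor` and `lu_solve` below are hypothetical helpers written for this example, they use exact `Fraction` arithmetic on a real matrix (so adjoint is just transpose), and they store the permutation as a 0-based row list with `A[perm[i]]` equal to row `i` of `L @ U`, which is an assumption of the sketch and differs from the 1-based pivots used by `torch.linalg.lu_factor`.

```python
from fractions import Fraction


def lu_factor(A):
    """Packed LU with partial pivoting: A[perm[i]] equals row i of L @ U."""
    n = len(A)
    LU = [[Fraction(x) for x in row] for row in A]
    perm = list(range(n))
    for k in range(n):
        # Use the row with the largest entry in column k as the pivot row.
        p = max(range(k, n), key=lambda i: abs(LU[i][k]))
        LU[k], LU[p] = LU[p], LU[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            LU[i][k] /= LU[k][k]  # multiplier, stored in the strictly lower part
            for j in range(k + 1, n):
                LU[i][j] -= LU[i][k] * LU[k][j]
    return LU, perm


def lu_solve(LU, perm, b, adjoint=False):
    """Solve A x = b, or A^T x = b when adjoint=True, from the packed factors."""
    n = len(LU)
    if not adjoint:
        # L U x = P b: forward solve with unit-lower L, then back solve with U.
        y = [Fraction(b[perm[i]]) for i in range(n)]
        for i in range(n):
            y[i] -= sum(LU[i][j] * y[j] for j in range(i))
        for i in reversed(range(n)):
            y[i] = (y[i] - sum(LU[i][j] * y[j] for j in range(i + 1, n))) / LU[i][i]
        return y
    # U^T L^T (P x) = b: forward solve with U^T, back solve with unit-upper L^T.
    w = [Fraction(v) for v in b]
    for i in range(n):
        w[i] = (w[i] - sum(LU[j][i] * w[j] for j in range(i))) / LU[i][i]
    for i in reversed(range(n)):
        w[i] -= sum(LU[j][i] * w[j] for j in range(i + 1, n))
    # Undo the row permutation: (P x)[i] = x[perm[i]].
    x = [Fraction(0)] * n
    for i in range(n):
        x[perm[i]] = w[i]
    return x


A = [[2, 1, 1], [4, 3, 3], [8, 7, 9]]
LU, perm = lu_factor(A)
x_true = [1, 2, 3]
b = [sum(A[i][j] * x_true[j] for j in range(3)) for i in range(3)]    # A x
b_t = [sum(A[j][i] * x_true[j] for j in range(3)) for i in range(3)]  # A^T x
assert lu_solve(LU, perm, b) == x_true
assert lu_solve(LU, perm, b_t, adjoint=True) == x_true
```

The point is that the adjoint solve reuses the same packed factors and only changes which triangle is read and in what order, so no second factorization is needed.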
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA when calling the batched MAGMA backend with trans=True. We work around that by solving the system solving two triangular systems. We also update the heuristics for this function, as they were fairly outdated. We found that cuSolver is king, so luckily we do not need to rely on the buggy backend from magma for this function. We added tests testing this function left and right. We also added tests for the different backends. We also activated the tests for AMD, as those should work as well. ### Benchmarking <details> <summary> Benchmark Results (adjoint=False) </summary> ``` --------------------------------------------------------------------------------------------- linalg.lu_solve CUDA ---------------------------------------------------------------------------------------------] | lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic 1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- shape torch.Size([1, 1, 1]) | 750 | 34 | 28 | 252 | 86 | 27 shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 239 | 94 | 27 shape torch.Size([4, 1, 1]) | 2995 | 83 | 28 | 239 | 100 | 27 shape torch.Size([8, 1, 1]) | 6000 | 146 | 28 | 239 | 94 | 27 shape torch.Size([16, 1, 1]) | 11900 | 272 | 28 | 241 | 95 | 27 shape torch.Size([32, 1, 1]) | 23880 | 524 | 28 | 244 | 94 | 27 shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 245 | 99 | 27 shape torch.Size([128, 1, 1]) | 96000 | 2054 | 28 | 242 | 96 | 27 shape torch.Size([512, 1, 1]) | 381900 | 8100 | 28 | 250 | 94 | 27 shape torch.Size([1024, 1, 1]) | 763800 | 16200 | 28 | 257 | 95 | 27 shape torch.Size([1, 2, 2]) | 750 | 33 | 28 | 240 | 88 | 27 shape torch.Size([2, 2, 2]) | 1500 | 51 | 28 | 240 | 96 | 27 shape 
torch.Size([4, 2, 2]) | 2991 | 82 | 28 | 241 | 96 | 28 shape torch.Size([8, 2, 2]) | 6000 | 150 | 28 | 241 | 96 | 27 shape torch.Size([16, 2, 2]) | 12000 | 275 | 28 | 242 | 96 | 27 shape torch.Size([32, 2, 2]) | 23980 | 530 | 28 | 246 | 97 | 28 shape torch.Size([64, 2, 2]) | 48000 | 1000 | 28 | 244 | 96 | 27 shape torch.Size([128, 2, 2]) | 96000 | 2063 | 28 | 245 | 96 | 28 shape torch.Size([512, 2, 2]) | 382000 | 8300 | 28 | 257 | 97 | 28 shape torch.Size([1024, 2, 2]) | 764000 | 20000 | 28 | 271 | 97 | 28 shape torch.Size([1, 8, 8]) | 749 | 34 | 28 | 243 | 88 | 28 shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 244 | 97 | 28 shape torch.Size([4, 8, 8]) | 2988 | 83 | 28 | 244 | 100 | 28 shape torch.Size([8, 8, 8]) | 5980 | 150 | 28 | 245 | 97 | 28 shape torch.Size([16, 8, 8]) | 12000 | 278 | 28 | 246 | 96 | 28 shape torch.Size([32, 8, 8]) | 23910 | 536 | 28 | 249 | 98 | 28 shape torch.Size([64, 8, 8]) | 47800 | 1100 | 28 | 247 | 96 | 28 shape torch.Size([128, 8, 8]) | 96000 | 2075 | 28 | 248 | 96 | 28 shape torch.Size([512, 8, 8]) | 382100 | 8300 | 28 | 270 | 97 | 28 shape torch.Size([1024, 8, 8]) | 764100 | 16400 | 28 | 291 | 100 | 28 shape torch.Size([1, 16, 16]) | 750 | 33 | 28 | 248 | 88 | 28 shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 250 | 97 | 28 shape torch.Size([4, 16, 16]) | 2996 | 83 | 28 | 250 | 100 | 28 shape torch.Size([8, 16, 16]) | 5980 | 147 | 28 | 251 | 97 | 28 shape torch.Size([16, 16, 16]) | 11900 | 274 | 28 | 251 | 97 | 28 shape torch.Size([32, 16, 16]) | 24040 | 527 | 28 | 252 | 97 | 28 shape torch.Size([64, 16, 16]) | 47800 | 1037 | 28 | 251 | 100 | 28 shape torch.Size([128, 16, 16]) | 95600 | 2044 | 28 | 252 | 98 | 28 shape torch.Size([512, 16, 16]) | 388200 | 8100 | 28 | 280 | 100 | 28 shape torch.Size([1024, 16, 16]) | 769700 | 16000 | 28 | 322 | 117 | 28 shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 255 | 88 | 28 shape torch.Size([2, 32, 32]) | 1510 | 50 | 28 | 256 | 97 | 28 shape torch.Size([4, 32, 32]) | 3022 | 82 | 31 | 256 | 97 
| 30 shape torch.Size([8, 32, 32]) | 6000 | 140 | 31 | 257 | 100 | 31 shape torch.Size([16, 32, 32]) | 12000 | 281 | 31 | 258 | 96 | 31 shape torch.Size([32, 32, 32]) | 24150 | 563 | 35 | 258 | 96 | 35 shape torch.Size([64, 32, 32]) | 48300 | 1119 | 36 | 258 | 97 | 36 shape torch.Size([128, 32, 32]) | 96500 | 2235 | 43 | 261 | 96 | 43 shape torch.Size([512, 32, 32]) | 383100 | 8930 | 82 | 317 | 191 | 82 shape torch.Size([1024, 32, 32]) | 766300 | 19200 | 122 | 400 | 312 | 122 shape torch.Size([1, 64, 64]) | 760 | 33 | 55 | 272 | 68 | 34 shape torch.Size([2, 64, 64]) | 1500 | 52 | 58 | 273 | 85 | 52 shape torch.Size([4, 64, 64]) | 3127 | 102 | 65 | 273 | 150 | 65 shape torch.Size([8, 64, 64]) | 6070 | 201 | 65 | 275 | 278 | 65 shape torch.Size([16, 64, 64]) | 12000 | 399 | 66 | 274 | 95 | 67 shape torch.Size([32, 64, 64]) | 23900 | 796 | 73 | 275 | 97 | 73 shape torch.Size([64, 64, 64]) | 48000 | 1594 | 75 | 283 | 123 | 76 shape torch.Size([128, 64, 64]) | 95000 | 3177 | 96 | 292 | 176 | 96 shape torch.Size([512, 64, 64]) | 379300 | 13520 | 208 | 426 | 551 | 208 shape torch.Size([1024, 64, 64]) | 758700 | 27100 | 306 | 570 | 919 | 306 shape torch.Size([1, 128, 128]) | 750 | 42 | 115 | 306 | 90 | 42 shape torch.Size([2, 128, 128]) | 1500 | 82 | 122 | 307 | 164 | 83 shape torch.Size([4, 128, 128]) | 2966 | 162 | 136 | 307 | 301 | 136 shape torch.Size([8, 128, 128]) | 5930 | 317 | 137 | 308 | 578 | 138 shape torch.Size([16, 128, 128]) | 12000 | 635 | 143 | 316 | 199 | 143 shape torch.Size([32, 128, 128]) | 23700 | 1266 | 152 | 322 | 241 | 152 shape torch.Size([64, 128, 128]) | 48000 | 2668 | 177 | 337 | 322 | 177 shape torch.Size([128, 128, 128]) | 96000 | 5366 | 228 | 365 | 514 | 228 shape torch.Size([512, 128, 128]) | 379400 | 21490 | 502 | 620 | 1697 | 502 shape torch.Size([1024, 128, 128]) | 755700 | 43040 | 764 | 903 | 3040 | 770 shape torch.Size([1, 256, 256]) | 750 | 70 | 235 | 383 | 178 | 72 shape torch.Size([2, 256, 256]) | 2000 | 138 | 250 | 384 | 329 | 139 
shape torch.Size([4, 256, 256]) | 2988 | 277 | 279 | 404 | 655 | 278 shape torch.Size([8, 256, 256]) | 6100 | 546 | 283 | 420 | 1321 | 286 shape torch.Size([16, 256, 256]) | 12100 | 1149 | 330 | 441 | 472 | 330 shape torch.Size([32, 256, 256]) | 24040 | 2303 | 359 | 453 | 634 | 360 shape torch.Size([64, 256, 256]) | 48000 | 4626 | 408 | 472 | 925 | 408 shape torch.Size([128, 256, 256]) | 94700 | 9247 | 543 | 543 | 1582 | 543 shape torch.Size([512, 256, 256]) | 372000 | 37030 | 1310 | 1185 | 5711 | 1310 shape torch.Size([1024, 256, 256]) | 747200 | 74100 | 2116 | 1910 | 10660 | 2122 ``` </details> <details> <summary> Benchmark Results (adjoint=True) </summary> ``` [----------------------------------------------------------------------------------------- linalg.lu_solve CUDA Adjoint -----------------------------------------------------------------------------------------] | lu_solve looped_magma | lu_solve looped cusolver | lu_solve batched cublas | lu_solve batched magma | lu_solve unpack+solve_triangular | lu_solve heuristic 1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- shape torch.Size([1, 1, 1]) | 749 | 34 | 28 | 33 | 98 | 27 shape torch.Size([2, 1, 1]) | 1500 | 50 | 28 | 50 | 110 | 27 shape torch.Size([4, 1, 1]) | 3005 | 82 | 28 | 81 | 110 | 27 shape torch.Size([8, 1, 1]) | 5999 | 145 | 28 | 140 | 110 | 27 shape torch.Size([16, 1, 1]) | 12000 | 273 | 28 | 77 | 110 | 27 shape torch.Size([32, 1, 1]) | 24000 | 522 | 28 | 78 | 110 | 27 shape torch.Size([64, 1, 1]) | 48000 | 1000 | 28 | 77 | 100 | 27 shape torch.Size([128, 1, 1]) | 96000 | 2029 | 28 | 78 | 110 | 27 shape torch.Size([512, 1, 1]) | 383300 | 8100 | 28 | 78 | 110 | 28 shape torch.Size([1024, 1, 1]) | 767500 | 16100 | 28 | 77 | 100 | 27 shape torch.Size([1, 2, 2]) | 753 | 33 | 28 | 33 | 99 | 28 shape torch.Size([2, 2, 2]) | 1500 | 50 | 28 | 50 
| 110 | 28 shape torch.Size([4, 2, 2]) | 3002 | 82 | 28 | 80 | 100 | 27 shape torch.Size([8, 2, 2]) | 6000 | 145 | 28 | 144 | 107 | 27 shape torch.Size([16, 2, 2]) | 12000 | 271 | 28 | 78 | 110 | 27 shape torch.Size([32, 2, 2]) | 24120 | 524 | 28 | 78 | 110 | 28 shape torch.Size([64, 2, 2]) | 48300 | 1030 | 28 | 78 | 111 | 27 shape torch.Size([128, 2, 2]) | 96100 | 2041 | 28 | 78 | 107 | 28 shape torch.Size([512, 2, 2]) | 383000 | 8100 | 28 | 79 | 108 | 28 shape torch.Size([1024, 2, 2]) | 766100 | 16000 | 28 | 78 | 110 | 28 shape torch.Size([1, 8, 8]) | 750 | 34 | 28 | 34 | 99 | 28 shape torch.Size([2, 8, 8]) | 1500 | 50 | 28 | 50 | 107 | 28 shape torch.Size([4, 8, 8]) | 2998 | 82 | 28 | 82 | 110 | 28 shape torch.Size([8, 8, 8]) | 5990 | 146 | 28 | 150 | 10000 107 | 28 shape torch.Size([16, 8, 8]) | 11980 | 272 | 28 | 79 | 107 | 28 shape torch.Size([32, 8, 8]) | 23970 | 530 | 28 | 79 | 110 | 28 shape torch.Size([64, 8, 8]) | 47900 | 1040 | 28 | 79 | 108 | 28 shape torch.Size([128, 8, 8]) | 96000 | 2048 | 28 | 78 | 108 | 28 shape torch.Size([512, 8, 8]) | 383700 | 8100 | 28 | 80 | 108 | 28 shape torch.Size([1024, 8, 8]) | 766200 | 16300 | 28 | 80 | 108 | 28 shape torch.Size([1, 16, 16]) | 760 | 33 | 28 | 34 | 99 | 28 shape torch.Size([2, 16, 16]) | 1500 | 50 | 28 | 50 | 110 | 28 [85/469] shape torch.Size([4, 16, 16]) | 3001 | 81 | 28 | 82 | 108 | 28 shape torch.Size([8, 16, 16]) | 6000 | 145 | 28 | 140 | 110 | 28 shape torch.Size([16, 16, 16]) | 12000 | 276 | 28 | 79 | 110 | 28 shape torch.Size([32, 16, 16]) | 23870 | 549 | 28 | 79 | 110 | 28 shape torch.Size([64, 16, 16]) | 47900 | 1098 | 29 | 80 | 100 | 28 shape torch.Size([128, 16, 16]) | 95800 | 2184 | 28 | 79 | 108 | 28 shape torch.Size([512, 16, 16]) | 386900 | 8769 | 28 | 80 | 108 | 28 shape torch.Size([1024, 16, 16]) | 769800 | 17460 | 37 | 80 | 107 | 37 shape torch.Size([1, 32, 32]) | 760 | 33 | 28 | 34 | 99 | 28 shape torch.Size([2, 32, 32]) | 1500 | 50 | 28 | 50 | 110 | 29 shape torch.Size([4, 32, 32]) | 
3021 | 86 | 31 | 84 | 110 | 32
shape torch.Size([8, 32, 32]) | 6040 | 167 | 32 | 167 | 108 | 32
shape torch.Size([16, 32, 32]) | 12100 | 330 | 33 | 78 | 107 | 33
shape torch.Size([32, 32, 32]) | 24150 | 662 | 35 | 78 | 110 | 35
shape torch.Size([64, 32, 32]) | 48200 | 1323 | 36 | 79 | 110 | 36
shape torch.Size([128, 32, 32]) | 97000 | 2637 | 44 | 79 | 110 | 43
shape torch.Size([512, 32, 32]) | 382500 | 10580 | 83 | 180 | 198 | 83
shape torch.Size([1024, 32, 32]) | 766600 | 22670 | 123 | 260 | 318 | 120
shape torch.Size([1, 64, 64]) | 760 | 33 | 58 | 33 | 88 | 34
shape torch.Size([2, 64, 64]) | 1520 | 60 | 60 | 60 | 109 | 59
shape torch.Size([4, 64, 64]) | 3016 | 115 | 67 | 119 | 132 | 66
shape torch.Size([8, 64, 64]) | 6120 | 230 | 67 | 233 | 180 | 68
shape torch.Size([16, 64, 64]) | 12100 | 457 | 69 | 86 | 107 | 69
shape torch.Size([32, 64, 64]) | 24000 | 912 | 74 | 95 | 100 | 74
shape torch.Size([64, 64, 64]) | 48000 | 1833 | 76 | 106 | 123 | 76
shape torch.Size([128, 64, 64]) | 95000 | 3636 | 97 | 163 | 183 | 97
shape torch.Size([512, 64, 64]) | 380600 | 15600 | 210 | 464 | 549 | 210
shape torch.Size([1024, 64, 64]) | 761200 | 31140 | 308 | 741 | 918 | 308
shape torch.Size([1, 128, 128]) | 756 | 46 | 120 | 47 | 89 | 46
shape torch.Size([2, 128, 128]) | 1500 | 91 | 123 | 89 | 110 | 92
shape torch.Size([4, 128, 128]) | 2994 | 178 | 139 | 180 | 131 | 139
shape torch.Size([8, 128, 128]) | 5960 | 350 | 140 | 354 | 208 | 142
shape torch.Size([16, 128, 128]) | 12000 | 701 | 144 | 177 | 198 | 143
shape torch.Size([32, 128, 128]) | 23870 | 1401 | 155 | 225 | 246 | 155
shape torch.Size([64, 128, 128]) | 47600 | 2948 | 179 | 288 | 323 | 180
shape torch.Size([128, 128, 128]) | 96000 | 5910 | 231 | 442 | 512 | 231
shape torch.Size([512, 128, 128]) | 381200 | 23640 | 519 | 1400 | 1700 | 519
shape torch.Size([1024, 128, 128]) | 755800 | 47340 | 794 | 2436 | 3018 | 794
shape torch.Size([1, 256, 256]) | 760 | 74 | 246 | 77 | 88 | 78
shape torch.Size([2, 256, 256]) | 1510 | 150 | 256 | 150 | 117 | 150
shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 209 | 284
shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 394 | 288
shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 480 | 330
shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 629 | 367
shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 921 | 414
shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 1579 | 553
shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 5616 | 1410
shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277

Times are in microseconds (us).
```

</details>

To generate the results above, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I ran the following script, changing the variable `name`. For `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncommenting the commented one).

<details>
<summary> Benchmarking script </summary>

```python
import torch
import pickle
import itertools
from functools import partial
from torch.utils.benchmark import Timer, Compare

name = "heuristic"
label = "lu_solve {}".format(name)
shapes = [1, 2, 8, 16, 32, 64, 128, 256]
batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)]
results = []
make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")

def f(LU, pivots, B, adjoint):
    # Reference path: unpack the factorization and do two triangular solves
    P, L, U = torch.lu_unpack(LU, pivots)
    if adjoint:
        X = torch.linalg.solve_triangular(U.mH, B, upper=False)
        return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X)
    else:
        X = P.mT @ B
        X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X)
        return torch.linalg.solve_triangular(U, X, upper=True, out=X)

for n, batch in itertools.product(shapes, batches):
    LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n)))
    B = make_arg(batch + (n, 1))
    print(LU.shape)
    stmt = "torch.linalg.lu_solve(LU, pivots, B, adjoint=adjoint)"
    #stmt = "f(LU, pivots, B, adjoint=adjoint)"
    for adjoint in (True, False):
        timer = Timer(stmt,
                      globals=globals(),
                      label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""),
                      description=label,
                      sub_label=f"shape {LU.shape}",
                      num_threads=1)
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open("{}_lu_solve.pickle".format(name), 'wb') as f:
    pickle.dump(results, f)
```

</details>

Finally, I joined all the results with the following script:

<details>
<summary> Script to join the results </summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
    "looped_magma",
    "looped cusolver",
    "batched cublas",
    "batched magma",
    "unpack+solve_triangular",
    "heuristic",
]

timers = []
for name in files:
    with open("{}_lu_solve.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```

</details>

### Fix for MAGMA's batched lu_solve when `adjoint=True`

I also developed the following workaround for MAGMA's bug, but I ended up not using it, preferring the triangular solves, as they were faster. I'm leaving it here in case it's useful in the future.
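The workaround relies on the fact that the LU of the transpose is not the transpose of the LU: if `A = L U` with `D = diag(U)`, then `A^T = (U^T D^{-1}) (D L^T)`, which again has a unit lower-triangular left factor. As a sanity check of that rescaling only, here is a minimal pure-Python sketch. It assumes no pivoting, real entries, and exact arithmetic via `fractions`; the helper names (`lu_nopivot`, `unpack`, `packed_lu_of_transpose`, ...) are hypothetical and are not PyTorch or MAGMA functions.

```python
from fractions import Fraction


def lu_nopivot(A):
    """Doolittle LU without pivoting, returned in packed form
    (strict lower part holds L, diagonal and upper part hold U)."""
    n = len(A)
    LU = [[Fraction(x) for x in row] for row in A]
    for k in range(n):
        for i in range(k + 1, n):
            LU[i][k] /= LU[k][k]
            for j in range(k + 1, n):
                LU[i][j] -= LU[i][k] * LU[k][j]
    return LU


def unpack(LU):
    """Split a packed LU into a unit lower-triangular L and an upper-triangular U."""
    n = len(LU)
    L = [[LU[i][j] if j < i else Fraction(int(i == j)) for j in range(n)] for i in range(n)]
    U = [[LU[i][j] if j >= i else Fraction(0) for j in range(n)] for i in range(n)]
    return L, U


def matmul(X, Y):
    n = len(X)
    return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def transpose(X):
    return [list(col) for col in zip(*X)]


def packed_lu_of_transpose(LU):
    """Rescale a packed LU of A into a packed LU of A^T.

    Strictly lower entries become L D, strictly upper entries become D^{-1} U,
    the diagonal is kept, and the whole matrix is then transposed.
    """
    n = len(LU)
    d = [LU[i][i] for i in range(n)]
    out = [[Fraction(0)] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i > j:
                scaled = LU[i][j] * d[j]   # (L D)[i][j]
            elif i < j:
                scaled = LU[i][j] / d[i]   # (D^{-1} U)[i][j]
            else:
                scaled = d[i]
            out[j][i] = scaled             # write transposed
    return out


A = [[2, 1, 1], [4, 3, 3], [8, 7, 9]]
L_t, U_t = unpack(packed_lu_of_transpose(lu_nopivot(A)))
print(matmul(L_t, U_t) == transpose(A))  # prints True
```

The sketch deliberately leaves out the permutation and the conjugation, both of which the actual workaround has to handle separately.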
<details>
<summary> Fix for MAGMA's issue with `adjoint=True` </summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  if (trans == TransposeType::NoTranspose) {
    lu_solve_batched_magma(LU, pivots, B, trans);
    return;
  }
  // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
  // The LU of the transpose is not the transpose of the LU
  // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
  auto diag = LU.diagonal(0, -2, -1);
  auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2)) + LU.triu(1).div_(diag.unsqueeze(-1));
  LU_f.diagonal(0, -2, -1).copy_(diag);
  if (trans == TransposeType::ConjTranspose) {
    LU_f = LU_f.conj_physical();
  }
  // transpose() returns a view rather than modifying in place, so assign the result back
  LU_f = LU_f.transpose(-2, -1);
  // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

  // Trivial permutation
  auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
  lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

  // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
  // Fill `perm` with the identity permutation (perhaps batched)
  // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
  const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
    .add_output(perm)
    .add_input(pivots)
    .build();
  unpack_pivots_stub(pivots.device().type(), iter, m);
  B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>

Fixes #61657 [ghstack-poisoned]
<details> <summary> Benchmark Results (adjoint=True) </summary> ```
150 | 117 | 150 shape torch.Size([4, 256, 256]) | 3030 | 296 | 284 | 296 | 209 | 284 shape torch.Size([8, 256, 256]) | 6100 | 588 | 286 | 592 | 394 | 288 shape torch.Size([16, 256, 256]) | 12200 | 1238 | 330 | 445 | 480 | 330 shape torch.Size([32, 256, 256]) | 24430 | 2476 | 368 | 568 | 629 | 367 shape torch.Size([64, 256, 256]) | 49000 | 4950 | 415 | 800 | 921 | 414 shape torch.Size([128, 256, 256]) | 96000 | 9900 | 552 | 1330 | 1579 | 553 shape torch.Size([512, 256, 256]) | 369400 | 39580 | 1410 | 4614 | 5616 | 1410 shape torch.Size([1024, 256, 256]) | 716200 | 79200 | 2270 | 8472 | 10500 | 2277 Times are in microseconds (us). ``` </details> To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `break;`. Then I run the following script, changing the variable `name`. For the `lu_solve unpack+solve_triangular`, I also changed the `stmt` variable (uncomenting the commented one) <details> <summary> Benchmarking script </summary> ```python import torch import pickle import itertools from functools import partial from torch.utils.benchmark import Timer, Compare name = "heuristic" label = "lu_solve {}".format(name) shapes = [1, 2, 8, 16, 32, 64, 128, 256] batches = [(1,), (2,), (4,), (8,), (16,), (32,), (64,), (128,), (512,), (1024,)] results = [] make_arg = partial(torch.randn, dtype=torch.float32, device="cuda") def f(LU, pivots, B, adjoint): P, L, U = torch.lu_unpack(LU, pivots) if adjoint: X = torch.linalg.solve_triangular(U.mH, B, upper=False) return P @ torch.linalg.solve_triangular(L.mH, X, upper=True, unitriangular=True, out=X) else: X = P.mT @ B X = torch.linalg.solve_triangular(L, X, upper=False, unitriangular=True, out=X) return torch.linalg.solve_triangular(U, X, upper=True, out=X) for n, batch in itertools.product(shapes, batches): LU, pivots = torch.linalg.lu_factor(make_arg(batch + (n, n))) B = make_arg(batch + (n, 1)) print(LU.shape) stmt = "torch.linalg.lu_solve(LU, pivots, B, 
adjoint=adjoint)" #stmt = "f(LU, pivots, B, adjoint=adjoint)" for adjoint in (True, False): timer = Timer(stmt, globals=globals(), label="linalg.lu_solve CUDA{}".format(" Adjoint" if adjoint else ""), description=label, sub_label=f"shape {LU.shape}", num_threads=1) results.append(timer.blocked_autorange()) compare = Compare(results) compare.trim_significant_figures() compare.print() with open("{}_lu_solve.pickle".format(name), 'wb') as f: pickle.dump(results, f) ``` </details> Finally, I joined all the results with the following script: <details> <summary> Script to join the results </summary> ```python import pickle from torch.utils.benchmark import Timer, Compare files = [ "looped_magma", "looped cusolver", "batched cublas", "batched magma", "unpack+solve_triangular", "heuristic", ] timers = [] for name in files: with open("{}_lu_solve.pickle".format(name), 'rb') as f: timers += pickle.load(f) compare = Compare(timers) compare.trim_significant_figures() compare.print() ``` </details> ### Fix for Magma's batched lu_solve when `adjoint=True` I also developed the following fix around MAGMA's bug, but I ended up not using it, and preferring the triangular solves over it, as they were faster. I'm leaving it here in case it's useful in the future. 
<details>
<summary>
Fix for MAGMA's issue with `adjoint=True`
</summary>

```cpp
auto lu_solve_batched_magma_fn = [m](const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  if (trans == TransposeType::NoTranspose) {
    lu_solve_batched_magma(LU, pivots, B, trans);
    return;
  }
  // There's a bug in magma for the other cases, so we need to properly perform mT or mH on LU
  // The LU of the transpose is not the transpose of the LU
  // We need to do LU = LDU' = L'U' where L' = LD, U' = D^{-1}U and D = diag(U)
  auto diag = LU.diagonal(0, -2, -1);
  auto LU_f = LU.tril(-1).mul_(diag.unsqueeze(-2))
            + LU.triu(1).div_(diag.unsqueeze(-1));
  LU_f.diagonal(0, -2, -1).copy_(diag);
  if (trans == TransposeType::ConjTranspose) {
    LU_f = LU_f.conj_physical();
  }
  // transpose() returns a view rather than acting in-place, so the result must be assigned
  LU_f = LU_f.transpose(-2, -1);
  // At this point LU_f is F-contiguous, because triu / tril / conj_physical return contiguous tensors
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(LU_f.mT().is_contiguous());

  // Trivial permutation
  auto pivots_aux = at::arange(1, m + 1, pivots.options()).expand_as(pivots).contiguous();
  lu_solve_batched_magma(LU_f, pivots_aux, B, TransposeType::NoTranspose);

  // We then need to multiply B by P on the right as (PLU)^T = B iff U^TL^T = BP
  // Fill `perm` with the identity permutation (perhaps batched)
  // This is faster than torch.lu_unpack + matmul, as this logic is borrowed from lu_unpack
  const auto perm = at::arange(m, pivots.options().dtype(kLong)).expand(pivots.sizes()).contiguous();
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(pivots.sizes(), /*squash_dim=*/pivots.dim() - 1)
    .add_output(perm)
    .add_input(pivots)
    .build();
  unpack_pivots_stub(pivots.device().type(), iter, m);
  B.scatter_(-2, perm.unsqueeze(-1).expand_as(B), B.clone());
};
```

</details>

Fixes #61657

[ghstack-poisoned]
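The "two triangular solves" idea used for the adjoint case can be illustrated in isolation. The block below is a minimal pure-Python sketch, not the CUDA code from this PR. It assumes real matrices (so the adjoint is just the transpose), stores the row permutation as a 0-based list `perm` instead of LAPACK-style pivots, and uses exact `Fraction` arithmetic so the result can be checked without tolerances. The helper names `lu_factor` and `lu_solve_transposed` are hypothetical and only mirror the role of the corresponding `torch.linalg` functions.

```python
from fractions import Fraction


def lu_factor(A):
    """LU with partial pivoting; returns (LU, perm) with A[perm[i]] == (L @ U)[i].

    L (unit diagonal) is stored below the diagonal of LU, U on and above it.
    """
    n = len(A)
    LU = [[Fraction(x) for x in row] for row in A]
    perm = list(range(n))
    for k in range(n):
        # Pick the row with the largest entry in column k as the pivot row.
        p = max(range(k, n), key=lambda i: abs(LU[i][k]))
        if LU[p][k] == 0:
            raise ValueError("matrix is singular")
        LU[k], LU[p] = LU[p], LU[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            LU[i][k] /= LU[k][k]
            for j in range(k + 1, n):
                LU[i][j] -= LU[i][k] * LU[k][j]
    return LU, perm


def lu_solve_transposed(LU, perm, b):
    """Solve A^T x = b from the factors of A, using two triangular solves."""
    n = len(LU)
    y = [Fraction(v) for v in b]
    # Forward substitution with U^T (lower triangular, diagonal of U).
    for i in range(n):
        for j in range(i):
            y[i] -= LU[j][i] * y[j]
        y[i] /= LU[i][i]
    # Back substitution with L^T (upper triangular, unit diagonal).
    for i in reversed(range(n)):
        for j in range(i + 1, n):
            y[i] -= LU[j][i] * y[j]
    # Undo the row permutation: x[perm[i]] = y[i].
    x = [Fraction(0)] * n
    for i in range(n):
        x[perm[i]] = y[i]
    return x


A = [[2, 1, 1], [4, 3, 3], [8, 7, 9]]
b = [1, 2, 3]
LU, perm = lu_factor(A)
x = lu_solve_transposed(LU, perm, b)
# Residual of A^T x - b; every entry should be exactly zero.
residual = [sum(A[i][j] * x[i] for i in range(3)) - b[j] for j in range(3)]
print(residual)
```

The point of the sketch is that the factors of `A` are reused as they are: the transposed system only swaps the roles of the two triangular factors and applies the permutation at the end instead of at the start, which is the same structure as the `adjoint` branch of `f` in the benchmarking script.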
…n B is a matrix"

When linalg_lu_solve was added in https://github.com/pytorch/pytorch/pull/72935 I made the big mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be false by a large margin. This PR amends this and provides a heuristic that takes the number of columns of B into account.

The heuristic is not simple and it was crafted by hand, but as the results show, it is effective.

xwang233, the cuSOLVER team should look into this one, as I was able to outperform both cuBLAS's and cuSOLVER's algorithms by using triangular solves...

**Edit:** These new heuristics yield a **3x-4x performance improvement** over the previous ones when the **rhs is a square matrix or batch of square matrices** (which is the case when computing inverses, gradients for the determinant, etc).

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[--------------------------------------------------------------------------------------------------------- linalg.lu_solve CUDA --------------------------------------------------------------------------------------------------------]
 | lu_solve looped_magma | lu_solve looped_cusolver | lu_solve batched_cublas | lu_solve batched_magma | lu_solve unpack+solve_triangular | lu_solve heuristic | previous_heuristic
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1, 1) (1, 1, 1) | 591 | 31 | 25 | 50 | 74 | 27 | 26
(2, 1, 1) (2, 1, 1) | 1200 | 47 | 25 | 41 | 81 | 25 | 26
(4, 1, 1) (4, 1, 1) | 2340 | 79 | 25 | 42 | 81 | 26 | 26
(8, 1, 1) (8, 1, 1) | 4680 | 140 | 25 | 43 | 84 | 25 | 26
(16, 1, 1) (16, 1, 1) | 9400 | 268 | 25 | 45 | 82 | 25 | 26
(32, 1, 1) (32, 1, 1) | 18700 | 519 | 25 | 41 | 81 | 26 | 26
(64, 1, 1) (64, 1, 1) | 37240 | 1020 | 25 | 39 | 83 | 26 | 26
(128, 1, 1) (128, 1, 1) | 75000 | 2021 | 26 | 41 | 81 | 26 | 26
(512, 1, 1) (512, 1, 1) | 299400 | 8000 | 25 | 50 | 83 | 26 | 25
(1024, 1, 1) (1024, 1, 1) | 598600 | 16100 | 26 | 55 | 82 | 26 | 26
(1, 1, 1) (1, 1, 8) | 596 | 31 | 25 | 24 | 75 | 25 | 26
(2, 1, 1) (2, 1, 8) | 1190 | 47 | 25 | 25 | 83 | 25 | 26
(4, 1, 1) (4, 1, 8) | 2367 | 78 | 26 | 25 | 83 | 25 | 26
(8, 1, 1) (8, 1, 8) | 4800 | 140 | 25 | 25 | 82 | 25 | 25
(16, 1, 1) (16, 1, 8) | 9500 | 269 | 26 | 26 | 82 | 25 | 25
(32, 1, 1) (32, 1, 8) | 18900 | 536 | 26 | 26 | 82 | 26 | 26
(64, 1, 1) (64, 1, 8) | 37710 | 1066 | 26 | 26 | 84 | 26 | 25
(128, 1, 1) (128, 1, 8) | 75400 | 2131 | 26 | 27 | 82 | 26 | 26
(512, 1, 1) (512, 1, 8) | 301900 | 8520 | 26 | 33 | 80 | 26 | 26
(1024, 1, 1) (1024, 1, 8) | 603600 | 17030 | 26 | 41 | 82 | 26 | 26
(1, 1, 1) (1, 1, 64) | 577 | 23 | 25 | 26 | 75 | 26 | 26
(2, 1, 1) (2, 1, 64) | 1160 | 34 | 26 | 27 | 83 | 26 | 26
(4, 1, 1) (4, 1, 64) | 2334 | 53 | 25 | 27 | 84 | 25 | 26
(8, 1, 1) (8, 1, 64) | 4650 | 92 | 26 | 28 | 82 | 25 | 26
(16, 1, 1) (16, 1, 64) | 9290 | 170 | 25 | 29 | 82 | 26 | 26
(32, 1, 1) (32, 1, 64) | 18600 | 326 | 25 | 29 | 84 | 26 | 26
(64, 1, 1) (64, 1, 64) | 37160 | 634 | 26 | 32 | 83 | 26 | 26
(128, 1, 1) (128, 1, 64) | 74000 | 1260 | 26 | 31 | 82 | 26 | 26
(512, 1, 1) (512, 1, 64) | 297300 | 4960 | 26 | 42 | 82 | 26 | 26
(1024, 1, 1) (1024, 1, 64) | 595400 | 9900 | 43 | 57 | 81 | 43 | 43
(1, 1, 1) (1, 1, 256) | 583 | 24 | 26 | 28 | 75 | 35 | 26
(2, 1, 1) (2, 1, 256) | 1170 | 34 | 25 | 28 | 82 | 27 | 26
(4, 1, 1) (4, 1, 256) | 2328 | 53 | 25 | 28 | 84 | 28 | 26
(8, 1, 1) (8, 1, 256) | 4660 | 93 | 25 | 29 | 83 | 29 | 26
(16, 1, 1) (16, 1, 256) | 9300 | 170 | 25 | 30 | 82 | 30 | 26
(32, 1, 1) (32, 1, 256) | 19000 | 326 | 25 | 35 | 82 | 32 | 26
(64, 1, 1) (64, 1, 256) | 37200 | 638 | 26 | 37 | 82 | 35 | 26
(128, 1, 1) (128, 1, 256) | 74000 | 1300 | 26 | 43 | 82 | 40 | 26
(512, 1, 1) (512, 1, 256) | 297900 | 5000 | 78 | 94 | 99 | 94 | 77
(1024, 1, 1) (1024, 1, 256) | 595100 | 9900 | 146 | 200 | 172 | 164 | 145
(1, 2, 2) (1, 2, 1) | 590 | 31 | 26 | 41 | 76 | 26 | 26
(2, 2, 2) (2, 2, 1) | 1190 | 47 | 26 | 44 | 80 | 26 | 26
(4, 2, 2) (4, 2, 1) | 2357 | 79 | 26 | 41 | 83 | 26 | 26
(8, 2, 2) (8, 2, 1) | 4700 | 140 | 26 | 41 | 82 | 26 | 26
(16, 2, 2) (16, 2, 1) | 9400 | 268 | 26 | 42 | 83 | 26 | 26
(32, 2, 2) (32, 2, 1) | 19000 | 518 | 26 | 44 | 83 | 26 | 26
(64, 2, 2) (64, 2, 1) | 37710 | 1020 | 26 | 42 | 90 | 26 | 26
(128, 2, 2) (128, 2, 1) | 75000 | 2021 | 26 | 42 | 90 | 26 | 26
(512, 2, 2) (512, 2, 1) | 299500 | 8000 | 26 | 50 | 85 | 26 | 26
(1024, 2, 2) (1024, 2, 1) | 599500 | 16030 | 26 | 67 | 99 | 26 | 26
(1, 2, 2) (1, 2, 8) | 600 | 30 | 25 | 26 | 83 | 25 | 25
(2, 2, 2) (2, 2, 8) | 1200 | 46 | 25 | 26 | 85 | 25 | 25
(4, 2, 2) (4, 2, 8) | 2372 | 77 | 25 | 30 | 89 | 27 | 25
(8, 2, 2) (8, 2, 8) | 4740 | 138 | 25 | 28 | 90 | 25 | 25
(16, 2, 2) (16, 2, 8) | 9500 | 274 | 25 | 27 | 90 | 25 | 25
(32, 2, 2) (32, 2, 8) | 19000 | 544 | 25 | 29 | 90 | 25 | 25
(64, 2, 2) (64, 2, 8) | 37680 | 1080 | 25 | 28 | 91 | 25 | 25
(128, 2, 2) (128, 2, 8) | 75500 | 2161 | 25 | 30 | 90 | 25 | 25
(512, 2, 2) (512, 2, 8) | 300700 | 8635 | 25 | 39 | 91 | 25 | 25
(1024, 2, 2) (1024, 2, 8) | 604600 | 17270 | 25 | 52 | 85 | 25 | 25
(1, 2, 2) (1, 2, 64) | 577 | 23 | 25 | 27 | 79 | 25 | 25
(2, 2, 2) (2, 2, 64) | 1157 | 33 | 25 | 28 | 90 | 25 | 25
(4, 2, 2) (4, 2, 64) | 2316 | 53 | 25 | 29 | 90 | 25 | 25
(8, 2, 2) (8, 2, 64) | 4700 | 91 | 25 | 30 | 90 | 25 | 25
(16, 2, 2) (16, 2, 64) | 9290 | 168 | 25 | 30 | 91 | 25 | 25
(32, 2, 2) (32, 2, 64) | 19000 | 322 | 25 | 31 | 90 | 25 | 25
(64, 2, 2) (64, 2, 64) | 37440 | 630 | 25 | 31 | 90 | 25 | 25
(128, 2, 2) (128, 2, 64) | 75000 | 1200 | 26 | 36 | 90 | 25 | 25
(512, 2, 2) (512, 2, 64) | 300000 | 4930 | 33 | 50 | 84 | 33 | 33
(1024, 2, 2) (1024, 2, 64) | 599200 | 9800 | 56 | 70 | 83 | 55 | 56
(1, 2, 2) (1, 2, 256) | 586 | 23 | 25 | 30 | 80 | 28 | 25
(2, 2, 2) (2, 2, 256) | 1170 | 33 | 25 | 31 | 84 | 29 | 25
(4, 2, 2) (4, 2, 256) | 2347 | 53 | 25 | 32 | 83 | 29 | 25
(8, 2, 2) (8, 2, 256) | 4690 | 92 | 25 | 30 | 84 | 30 | 25
(16, 2, 2) (16, 2, 256) | 9400 | 170 | 26 | 31 | 90 | 31 | 25
(32, 2, 2) (32, 2, 256) | 19000 | 322 | 25 | 35 | 83 | 34 | 25
(64, 2, 2) (64, 2, 256) | 37540 | 628 | 25 | 38 | 83 | 36 | 25
(128, 2, 2) (128, 2, 256) | 75000 | 1260 | 37 | 44 | 83 | 43 | 37
(512, 2, 2) (512, 2, 256) | 299500 | 4960 | 100 | 110 | 133 | 103 | 103
(1024, 2, 2) (1024, 2, 256) | 585400 | 10000 | 211 | 210 | 265 | 195 | 209
(1, 8, 8) (1, 8, 1) | 577 | 33 | 26 | 44 | 76 | 26 | 26
(2, 8, 8) (2, 8, 1) | 1200 | 48 | 26 | 45 | 90 | 26 | 26
(4, 8, 8) (4, 8, 1) | 2301 | 80 | 26 | 40 | 84 | 26 | 26
(8, 8, 8) (8, 8, 1) | 4600 | 142 | 26 | 50 | 83 | 26 | 26
(16, 8, 8) (16, 8, 1) | 9100 | 267 | 26 | 50 | 82 | 26 | 26
(32, 8, 8) (32, 8, 1) | 19000 | 519 | 26 | 50 | 83 | 26 | 26
(64, 8, 8) (64, 8, 1) | 36980 | 1020 | 26 | 46 | 90 | 26 | 26
(128, 8, 8) (128, 8, 1) | 74000 | 2053 | 26 | 50 | 85 | 26 | 26
(512, 8, 8) (512, 8, 1) | 295100 | 8100 | 26 | 64 | 82 | 26 | 26
(1024, 8, 8) (1024, 8, 1) | 593200 | 16100 | 26 | 90 | 82 | 26 | 26
(1, 8, 8) (1, 8, 8) | 590 | 32 | 25 | 28 | 77 | 25 | 25
(2, 8, 8) (2, 8, 8) | 1200 | 47 | 25 | 28 | 83 | 25 | 25
(4, 8, 8) (4, 8, 8) | 2322 | 80 | 25 | 29 | 83 | 25 | 25
(8, 8, 8) (8, 8, 8) | 4640 | 145 | 25 | 28 | 83 | 25 | 26
(16, 8, 8) (16, 8, 8) | 9300 | 286 | 25 | 29 | 84 | 26 | 25
(32, 8, 8) (32, 8, 8) | 18500 | 568 | 26 | 32 | 85 | 26 | 25
(64, 8, 8) (64, 8, 8) | 37020 | 1200 | 26 | 30 | 83 | 25 | 25
(128, 8, 8) (128, 8, 8) | 74000 | 2261 | 26 | 31 | 83 | 26 | 25
(512, 8, 8) (512, 8, 8) | 294300 | 9030 | 25 | 41 | 83 | 25 | 25
(1024, 8, 8) (1024, 8, 8) | 592700 | 20000 | 26 | 60 | 82 | 26 | 26
(1, 8, 8) (1, 8, 64) | 564 | 24 | 25 | 27 | 76 | 25 | 25
(2, 8, 8) (2, 8, 64) | 1100 | 34 | 25 | 28 | 82 | 25 | 25
(4, 8, 8) (4, 8, 64) | 2289 | 53 | 25 | 29 | 84 | 25 | 25
(8, 8, 8) (8, 8, 64) | 4500 | 92 | 25 | 30 | 84 | 25 | 25
(16, 8, 8) (16, 8, 64) | 9200 | 170 | 25 | 30 | 83 | 25 | 25
(32, 8, 8) (32, 8, 64) | 18000 | 322 | 25 | 33 | 83 | 25 | 25
(64, 8, 8) (64, 8, 64) | 36870 | 629 | 28 | 32 | 82 | 27 | 28
(128, 8, 8) (128, 8, 64) | 73000 | 1250 | 36 | 38 | 82 | 36 | 35
(512, 8, 8) (512, 8, 64) | 292500 | 5000 | 79 | 88 | 104 | 78 | 79
(1024, 8, 8) (1024, 8, 64) | 582300 | 9900 | 154 | 159 | 196 | 154 | 155
(1, 8, 8) (1, 8, 256) | 582 | 25 | 25 | 28 | 76 | 28 | 25
(2, 8, 8) (2, 8, 256) | 1160 | 35 | 25 | 30 | 83 | 30 | 26
(4, 8, 8) (4, 8, 256) | 2304 | 54 | 31 | 31 | 83 | 30 | 26
(8, 8, 8) (8, 8, 256) | 4600 | 92 | 38 | 34 | 82 | 33 | 41
(16, 8, 8) (16, 8, 256) | 9200 | 169 | 64 | 37 | 82 | 37 | 65
(32, 8, 8) (32, 8, 256) | 18000 | 324 | 74 | 44 | 82 | 42 | 73
(64, 8, 8) (64, 8, 256) | 37090 | 632 | 88 | 61 | 83 | 58 | 88
(128, 8, 8) (128, 8, 256) | 74000 | 1250 | 116 | 91 | 101 | 90 | 115
(512, 8, 8) (512, 8, 256) | 304900 | 5000 | 382 | 314 | 356 | 308 | 383
(1024, 8, 8) (1024, 8, 256) | 609500 | 9800 | 774 | 615 | 672 | 590 | 775
(1, 16, 16) (1, 16, 1) | 596 | 33 | 26 | 46 | 76 | 31 | 25
(2, 16, 16) (2, 16, 1) | 1200 | 48 | 26 | 47 | 83 | 47 | 26
(4, 16, 16) (4, 16, 1) | 2371 | 79 | 26 | 49 | 83 | 26 | 26
(8, 16, 16) (8, 16, 1) | 4726 | 143 | 26 | 50 | 84 | 26 | 26
(16, 16, 16) (16, 16, 1) | 9400 | 270 | 26 | 49 | 83 | 26 | 26
(32, 16, 16) (32, 16, 1) | 18800 | 520 | 26 | 50 | 83 | 26 | 26
(64, 16, 16) (64, 16, 1) | 37280 | 1020 | 26 | 48 | 84 | 26 | 25
(128, 16, 16) (128, 16, 1) | 74000 | 2018 | 26 | 50 | 83 | 26 | 26
(512, 16, 16) (512, 16, 1) | 296300 | 8000 | 26 | 77 | 82 | 26 | 25
(1024, 16, 16) (1024, 16, 1) | 593800 | 16000 | 28 | 110 | 82 | 27 | 28
(1, 16, 16) (1, 16, 8) | 590 | 30 | 25 | 27 | 76 | 30 | 25
(2, 16, 16) (2, 16, 8) | 1200 | 46 | 25 | 29 | 82 | 46 | 25
(4, 16, 16) (4, 16, 8) | 2381 | 82 | 25 | 30 | 82 | 26 | 25
(8, 16, 16) (8, 16, 8) | 4700 | 200 | 25 | 30 | 84 | 26 | 25
(16, 16, 16) (16, 16, 8) | 9400 | 320 | 25 | 30 | 83 | 26 | 25
(32, 16, 16) (32, 16, 8) | 19000 | 640 | 25 | 31 | 83 | 26 | 25
(64, 16, 16) (64, 16, 8) | 37260 |
1234 | 25 | 31 | 91 | 26 | 25 (128, 16, 16) (128, 16, 8) | 74000 | 2542 | 25 | 33 | 90 | 26 | 25 (512, 16, 16) (512, 16, 8) | 299800 | 10000 | 36 | 45 | 82 | 36 | 36 (1024, 16, 16) (1024, 16, 8) | 602300 | 20000 | 56 | 63 | 93 | 56 | 56 (1, 16, 16) (1, 16, 64) | 582 | 24 | 25 | 29 | 76 | 23 | 25 (2, 16, 16) (2, 16, 64) | 1200 | 34 | 25 | 30 | 82 | 32 | 25 (4, 16, 16) (4, 16, 64) | 2318 | 53 | 25 | 30 | 83 | 83 | 25 (8, 16, 16) (8, 16, 64) | 4680 | 92 | 28 | 31 | 83 | 84 | 27 (16, 16, 16) (16, 16, 64) | 9290 | 181 | 42 | 32 | 83 | 84 | 40 (32, 16, 16) (32, 16, 64) | 19000 | 357 | 69 | 35 | 89 | 84 | 70 (64, 16, 16) (64, 16, 64) | 37160 | 714 | 78 | 38 | 82 | 84 | 77 (128, 16, 16) (128, 16, 64) | 74000 | 1420 | 94 | 54 | 81 | 82 | 92 (512, 16, 16) (512, 16, 64) | 299100 | 5836 | 222 | 145 | 235 | 233 | 222 (1024, 16, 16) (1024, 16, 64) | 600700 | 11730 | 572 | 263 | 411 | 410 | 566 (1, 16, 16) (1, 16, 256) | 594 | 23 | 43 | 33 | 76 | 23 | 40 (2, 16, 16) (2, 16, 256) | 1200 | 34 | 48 | 34 | 88 | 33 | 49 (4, 16, 16) (4, 16, 256) | 2378 | 62 | 94 | 35 | 82 | 84 | 91 (8, 16, 16) (8, 16, 256) | 4730 | 118 | 127 | 38 | 82 | 84 | 120 (16, 16, 16) (16, 16, 256) | 9500 | 244 | 142 | 45 | 83 | 84 | 141 (32, 16, 16) (32, 16, 256) | 19000 | 477 | 240 | 59 | 81 | 82 | 239 (64, 16, 16) (64, 16, 256) | 37610 | 969 | 279 | 90 | 118 | 118 | 277 (128, 16, 16) (128, 16, 256) | 76000 | 1960 | 340 | 156 | 218 | 217 | 343 (512, 16, 16) (512, 16, 256) | 302200 | 7910 | 1151 | 542 | 765 | 770 | 1145 (1024, 16, 16) (1024, 16, 256) | 603400 | 15810 | 2404 | 1040 | 1500 | 1490 | 2299 (1, 32, 32) (1, 32, 1) | 591 | 31 | 28 | 53 | 77 | 30 | 28 (2, 32, 32) (2, 32, 1) | 1200 | 46 | 28 | 54 | 83 | 46 | 28 (4, 32, 32) (4, 32, 1) | 2312 | 78 | 31 | 55 | 83 | 31 | 31 (8, 32, 32) (8, 32, 1) | 4600 | 144 | 31 | 56 | 83 | 31 | 31 (16, 32, 32) (16, 32, 1) | 9300 | 287 | 32 | 56 | 82 | 32 | 32 (32, 32, 32) (32, 32, 1) | 18700 | 572 | 35 | 56 | 83 | 35 | 35 (64, 32, 32) (64, 32, 1) | 37620 | 1140 | 36 | 55 
| 83 | 36 | 36 (128, 32, 32) (128, 32, 1) | 75000 | 2275 | 43 | 59 | 82 | 43 | 43 (512, 32, 32) (512, 32, 1) | 300700 | 9080 | 83 | 100 | 131 | 83 | 83 (1024, 32, 32) (1024, 32, 1) | 602000 | 19500 | 123 | 190 | 175 | 123 | 123 (1, 32, 32) (1, 32, 8) | 600 | 30 | 25 | 252 | 76 | 29 | 25 (2, 32, 32) (2, 32, 8) | 1200 | 49 | 25 | 253 | 83 | 49 | 25 (4, 32, 32) (4, 32, 8) | 2388 | 94 | 30 | 254 | 83 | 30 | 31 (8, 32, 32) (8, 32, 8) | 4800 | 183 | 33 | 257 | 83 | 32 | 32 (16, 32, 32) (16, 32, 8) | 9600 | 365 | 34 | 258 | 83 | 35 | 35 (32, 32, 32) (32, 32, 8) | 19100 | 727 | 45 | 260 | 83 | 44 | 44 (64, 32, 32) (64, 32, 8) | 38230 | 1453 | 45 | 260 | 83 | 45 | 45 (128, 32, 32) (128, 32, 8) | 76400 | 2901 | 51 | 262 | 83 | 10000 50 | 51 (512, 32, 32) (512, 32, 8) | 306900 | 12000 | 96 | 303 | 142 | 96 | 95 (1024, 32, 32) (1024, 32, 8) | 610900 | 24520 | 162 | 348 | 191 | 162 | 161 (1, 32, 32) (1, 32, 64) | 586 | 23 | 37 | 264 | 76 | 23 | 38 (2, 32, 32) (2, 32, 64) | 1200 | 39 | 44 | 268 | 85 | 41 | 44 (4, 32, 32) (4, 32, 64) | 2367 | 75 | 76 | 270 | 83 | 83 | 77 (8, 32, 32) (8, 32, 64) | 4750 | 145 | 86 | 273 | 83 | 83 | 84 (16, 32, 32) (16, 32, 64) | 9500 | 298 | 124 | 275 | 83 | 83 | 125 (32, 32, 32) (32, 32, 64) | 18900 | 594 | 239 | 272 | 82 | 82 | 238 (64, 32, 32) (64, 32, 64) | 38020 | 1157 | 265 | 276 | 110 | 110 | 262 (128, 32, 32) (128, 32, 64) | 76000 | 2320 | 320 | 293 | 179 | 180 | 306 (512, 32, 32) (512, 32, 64) | 304300 | 9611 | 840 | 533 | 613 | 614 | 835 (1024, 32, 32) (1024, 32, 64) | 605700 | 19190 | 2000 | 910 | 1130 | 1135 | 1900 (1, 32, 32) (1, 32, 256) | 610 | 34 | 87 | 257 | 77 | 31 | 84 (2, 32, 32) (2, 32, 256) | 1300 | 67 | 200 | 265 | 83 | 63 | 179 (4, 32, 32) (4, 32, 256) | 2441 | 126 | 270 | 264 | 84 | 84 | 247 (8, 32, 32) (8, 32, 256) | 4900 | 238 | 274 | 270 | 83 | 83 | 278 (16, 32, 32) (16, 32, 256) | 9600 | 478 | 450 | 281 | 107 | 107 | 433 (32, 32, 32) (32, 32, 256) | 20000 | 972 | 960 | 295 | 172 | 174 | 903 (64, 32, 32) (64, 32, 256) | 
38910 | 1915 | 1000 | 352 | 308 | 309 | 993 (128, 32, 32) (128, 32, 256) | 80000 | 3833 | 1200 | 620 | 572 | 572 | 1150 (512, 32, 32) (512, 32, 256) | 313000 | 15410 | 3723 | 2000 | 2173 | 2161 | 3621 (1024, 32, 32) (1024, 32, 256) | 625200 | 30880 | 8000 | 3277 | 4276 | 4279 | 7700 (1, 64, 64) (1, 64, 1) | 609 | 30 | 55 | 68 | 66 | 30 | 30 (2, 64, 64) (2, 64, 1) | 1201 | 53 | 61 | 70 | 85 | 53 | 53 (4, 64, 64) (4, 64, 1) | 2393 | 102 | 70 | 70 | 110 | 65 | 65 (8, 64, 64) (8, 64, 1) | 4824 | 200 | 68 | 71 | 155 | 66 | 66 (16, 64, 64) (16, 64, 1) | 9600 | 401 | 70 | 72 | 84 | 67 | 67 (32, 64, 64) (32, 64, 1) | 19000 | 802 | 76 | 72 | 88 | 73 | 73 (64, 64, 64) (64, 64, 1) | 38890 | 1601 | 76 | 72 | 90 | 76 | 76 (128, 64, 64) (128, 64, 1) | 76000 | 3198 | 100 | 77 | 115 | 96 | 96 (512, 64, 64) (512, 64, 1) | 301700 | 13720 | 210 | 216 | 286 | 209 | 209 (1024, 64, 64) (1024, 64, 1) | 598400 | 28090 | 308 | 355 | 386 | 308 | 308 (1, 64, 64) (1, 64, 8) | 600 | 37 | 48 | 302 | 66 | 37 | 36 (2, 64, 64) (2, 64, 8) | 1190 | 69 | 52 | 310 | 86 | 69 | 69 (4, 64, 64) (4, 64, 8) | 2383 | 135 | 64 | 312 | 112 | 64 | 64 (8, 64, 64) (8, 64, 8) | 4730 | 267 | 67 | 317 | 200 | 66 | 66 (16, 64, 64) (16, 64, 8) | 9500 | 535 | 73 | 318 | 83 | 73 | 73 (32, 64, 64) (32, 64, 8) | 19000 | 1060 | 94 | 323 | 83 | 94 | 94 (64, 64, 64) (64, 64, 8) | 37620 | 2129 | 97 | 329 | 86 | 96 | 97 (128, 64, 64) (128, 64, 8) | 75000 | 4245 | 113 | 343 | 119 | 113 | 113 (512, 64, 64) (512, 64, 8) | 299900 | 18150 | 240 | 444 | 312 | 240 | 240 (1024, 64, 64) (1024, 64, 8) | 602200 | 36350 | 410 | 608 | 465 | 410 | 410 (1, 64, 64) (1, 64, 64) | 580 | 35 | 80 | 314 | 59 | 57 | 34 (2, 64, 64) (2, 64, 64) | 1160 | 64 | 86 | 322 | 71 | 70 | 64 (4, 64, 64) (4, 64, 64) | 2323 | 131 | 162 | 326 | 85 | 85 | 162 (8, 64, 64) (8, 64, 64) | 4640 | 254 | 177 | 331 | 147 | 147 | 177 (16, 64, 64) (16, 64, 64) | 9320 | 505 | 271 | 345 | 108 | 108 | 274 (32, 64, 64) (32, 64, 64) | 18700 | 980 | 532 | 361 | 151 | 151 | 533 
(64, 64, 64) (64, 64, 64) | 37290 | 1998 | 595 | 374 | 238 | 238 | 593 (128, 64, 64) (128, 64, 64) | 75000 | 4124 | 717 | 465 | 414 | 415 | 718 (512, 64, 64) (512, 64, 64) | 299900 | 16430 | 2522 | 2114 | 1511 | 1500 | 2739 (1024, 64, 64) (1024, 64, 64) | 596100 | 32950 | 5900 | 4484 | 2859 | 2852 | 5600 (1, 64, 64) (1, 64, 256) | 607 | 59 | 259 | 319 | 59 | 57 | 60 (2, 64, 64) (2, 64, 256) | 1200 | 112 | 380 | 328 | 71 | 70 | 116 (4, 64, 64) (4, 64, 256) | 2429 | 225 | 535 | 332 | 106 | 106 | 535 (8, 64, 64) (8, 64, 256) | 4850 | 444 | 576 | 349 | 189 | 190 | 578 (16, 64, 64) (16, 64, 256) | 9690 | 882 | 960 | 386 | 232 | 232 | 962 (32, 64, 64) (32, 64, 256) | 19000 | 1753 | 2022 | 472 | 399 | 400 | 2024 (64, 64, 64) (64, 64, 256) | 38960 | 3546 | 2249 | 1126 | 742 | 741 | 2250 (128, 64, 64) (128, 64, 256) | 78000 | 7100 | 2730 | 2669 | 1416 | 1410 | 2734 (512, 64, 64) (512, 64, 256) | 308900 | 28150 | 12000 | 11000 | 5600 | 5568 | 12000 (1024, 64, 64) (1024, 64, 256) | 615400 | 56470 | 25350 | 22430 | 11000 | 11000 | 25100 (1, 128, 128) (1, 128, 1) | 584 | 42 | 115 | 100 | 67 | 42 | 43 (2, 128, 128) (2, 128, 1) | 1200 | 83 | 122 | 100 | 86 | 83 | 83 (4, 128, 128) (4, 128, 1) | 2312 | 162 | 136 | 100 | 113 | 137 | 136 (8, 128, 128) (8, 128, 1) | 4620 | 324 | 138 | 100 | 197 | 138 | 138 (16, 128, 128) (16, 128, 1) | 9200 | 642 | 143 | 100 | 159 | 144 | 144 (32, 128, 128) (32, 128, 1) | 18400 | 1291 | 153 | 103 | 163 | 153 | 153 (64, 128, 128) (64, 128, 1) | 36770 | 2711 | 178 | 120 | 191 | 178 | 178 (128, 128, 128) (128, 128, 1) | 74000 | 5447 | 231 | 150 | 251 | 231 | 231 (512, 128, 128) (512, 128, 1) | 292800 | 21820 | 505 | 406 | 639 | 505 | 505 (1024, 128, 128) (1024, 128, 1) | 585700 | 43630 | 769 | 685 | 911 | 769 | 769 (1, 128, 128) (1, 128, 8) | 590 | 58 | 102 | 416 | 70 | 65 | 57 (2, 128, 128) (2, 128, 8) | 1170 | 111 | 109 | 436 | 98 | 98 | 111 (4, 128, 128) (4, 128, 8) | 2329 | 221 | 133 | 446 | 170 | 168 | 133 (8, 128, 128) (8, 128, 8) | 4620 | 437 | 
139 | 453 | 308 | 306 | 138 (16, 128, 128) (16, 128, 8) | 9260 | 863 | 154 | 460 | 144 | 145 | 154 (32, 128, 128) (32, 128, 8) | 19000 | 1723 | 202 | 480 | 158 | 155 | 202 (64, 128, 128) (64, 128, 8) | 38360 | 3697 | 223 | 493 | 184 | 185 | 224 (128, 128, 128) (128, 128, 8) | 100000 | 7435 | 267 | 532 | 262 | 263 | 268 (512, 128, 128) (512, 128, 8) | 291300 | 29760 | 595 | 827 | 745 | 744 | 596 (1024, 128, 128) (1024, 128, 8) | 572900 | 59560 | 1290 | 1372 | 1132 | 1133 | 1308 (1, 128, 128) (1, 128, 64) | 567 | 61 | 164 | 447 | 59 | 60 | 63 (2, 128, 128) (2, 128, 64) | 1100 | 126 | 206 | 467 | 96 | 98 | 124 (4, 128, 128) (4, 128, 64) | 2848 | 252 | 350 | 490 | 167 | 168 | 352 (8, 128, 128) (8, 128, 64) | 4550 | 482 | 377 | 498 | 307 | 309 | 376 (16, 128, 128) (16, 128, 64) | 9300 | 972 | 578 | 522 | 215 | 217 | 580 (32, 128, 128) (32, 128, 64) | 19000 | 2022 | 1120 | 561 | 312 | 313 | 1120 (64, 128, 128) (64, 128, 64) | 37780 | 4128 | 1260 | 615 | 500 | 501 | 1260 (128, 128, 128) (128, 128, 64) | 75400 | 8234 | 1553 | 1111 | 876 | 876 | 1550 (512, 128, 128) (512, 128, 64) | 303200 | 32920 | 7600 | 6050 | 3345 | 3366 | 7840 (1024, 128, 128) (1024, 128, 64) | 604500 | 65880 | 16500 | 13000 | 6410 | 6400 | 16400 (1, 128, 128) (1, 128, 256) | 648 | 117 | 632 | 460 | 74 | 74 | 116 (2, 128, 128) (2, 128, 256) | 1300 | 231 | 825 | 480 | 126 | 126 | 227 (4, 128, 128) (4, 128, 256) | 2586 | 456 | 1100 | 497 | 227 | 226 | 1100 (8, 128, 128) (8, 128, 256) | 5170 | 920 | 1190 | 542 | 426 | 426 | 1190 (16, 128, 128) (16, 128, 256) | 10400 | 1841 | 2024 | 688 | 481 | 482 | 2035 (32, 128, 128) (32, 128, 256) | 20680 | 3702 | 4243 | 1220 | 836 | 836 | 4252 (64, 128, 128) (64, 128, 256) | 41350 | 7421 | 4803 | 3260 | 1600 | 1600 | 4810 (128, 128, 128) (128, 128, 256) | 83000 | 14850 | 6470 | 7600 | 3075 | 3072 | 6470 (512, 128, 128) (512, 128, 256) | 331800 | 59180 | 30510 | 30480 | …
When `linalg_lu_solve` was added in https://github.com/pytorch/pytorch/pull/72935, I made the big mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be false by a large margin. This PR fixes that and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and it was crafted by hand but, as the results show, it is effective.

xwang233, the cuSOLVER team should look into this one, as I was able to outperform both the cuBLAS and cuSOLVER algorithms by using triangular solves.

**Edit:** These new heuristics yield a **3x-4x performance improvement** over the previous ones when the **rhs is a square matrix or a batch of square matrices** (which is the case when computing inverses, gradients for the determinant, etc.).

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[------------------------------ linalg.lu_solve CUDA ------------------------------]
 | lu_solve looped_magma | lu_solve looped_cusolver | lu_solve batched_cublas | lu_solve batched_magma | lu_solve unpack+solve_triangular | lu_solve heuristic | previous_heuristic
1 threads: -------------------------------------------------------------------------
(1, 1, 1) (1, 1, 1) | 591 | 31 | 25 | 50 | 74 | 27 | 26
(2, 1, 1) (2, 1, 1) | 1200 | 47 | 25 | 41 | 81 | 25 | 26
(4, 1, 1) (4, 1, 1) | 2340 | 79 | 25 | 42 | 81 | 26 | 26
(8, 1, 1) (8, 1, 1) | 4680 | 140 | 25 | 43 | 84 | 25 | 26
(16, 1, 1) (16, 1, 1) | 9400 | 268 | 25 | 45 | 82 | 25 | 26
(32, 1, 1) (32, 1, 1) | 18700 | 519 | 25 | 41 | 81 | 26 | 26
(64, 1, 1) (64, 1, 1) | 37240 | 1020 | 25 | 39 | 83 | 26 | 26
(128, 1, 1) (128, 1, 1) | 75000 | 2021 | 26 | 41 | 81 | 26 | 26
(512, 1, 1) (512, 1, 1) | 299400 | 8000 | 25 | 50 | 83 | 26 | 25
(1024, 1, 1) (1024, 1, 1) | 598600 | 16100 | 26 | 55 | 82 | 26 | 26
(1, 1, 1) (1, 1, 8) | 596 | 31 | 25 | 24 | 75 | 25 | 26
(2, 1, 1) (2, 1, 8) | 1190 | 47 | 25 | 25 | 83 | 25 | 26
(4, 1, 1) (4, 1, 8) | 2367 | 78 | 26 | 25 | 83 | 25 | 26
(8, 1, 1) (8, 1, 8) | 4800 | 140 | 25 | 25 | 82 | 25 | 25
(16, 1, 1) (16, 1, 8) | 9500 | 269 | 26 | 26 | 82 | 25 | 25
(32, 1, 1) (32, 1, 8) | 18900 | 536 | 26 | 26 | 82 | 26 | 26
(64, 1, 1) (64, 1, 8) | 37710 | 1066 | 26 | 26 | 84 | 26 | 25
(128, 1, 1) (128, 1, 8) | 75400 | 2131 | 26 | 27 | 82 | 26 | 26
(512, 1, 1) (512, 1, 8) | 301900 | 8520 | 26 | 33 | 80 | 26 | 26
(1024, 1, 1) (1024, 1, 8) | 603600 | 17030 | 26 | 41 | 82 | 26 | 26
(1, 1, 1) (1, 1, 64) | 577 | 23 | 25 | 26 | 75 | 26 | 26
(2, 1, 1) (2, 1, 64) | 1160 | 34 | 26 | 27 | 83 | 26 | 26
(4, 1, 1) (4, 1, 64) | 2334 | 53 | 25 | 27 | 84 | 25 | 26
(8, 1, 1) (8, 1, 64) | 4650 | 92 | 26 | 28 | 82 | 25 | 26
(16, 1, 1) (16, 1, 64) | 9290 | 170 | 25 | 29 | 82 | 26 | 26
(32, 1, 1) (32, 1, 64) | 18600 | 326 | 25 | 29 | 84 | 26 | 26
(64, 1, 1) (64, 1, 64) | 37160 | 634 | 26 | 32 | 83 | 26 | 26
(128, 1, 1) (128, 1, 64) | 74000 | 1260 | 26 | 31 | 82 | 26 | 26
(512, 1, 1) (512, 1, 64) | 297300 | 4960 | 26 | 42 | 82 | 26 | 26
(1024, 1, 1) (1024, 1, 64) | 595400 | 9900 | 43 | 57 | 81 | 43 | 43
(1, 1, 1) (1, 1, 256) | 583 | 24 | 26 | 28 | 75 | 35 | 26
(2, 1, 1) (2, 1, 256) | 1170 | 34 | 25 | 28 | 82 | 27 | 26
(4, 1, 1) (4, 1, 256) | 2328 | 53 | 25 | 28 | 84 | 28 | 26
(8, 1, 1) (8, 1, 256) | 4660 | 93 | 25 | 29 | 83 | 29 | 26
(16, 1, 1) (16, 1, 256) | 9300 | 170 | 25 | 30 | 82 | 30 | 26
(32, 1, 1) (32, 1, 256) | 19000 | 326 | 25 | 35 | 82 | 32 | 26
(64, 1, 1) (64, 1, 256) | 37200 | 638 | 26 | 37 | 82 | 35 | 26
(128, 1, 1) (128, 1, 256) | 74000 | 1300 | 26 | 43 | 82 | 40 | 26
(512, 1, 1) (512, 1, 256) | 297900 | 5000 | 78 | 94 | 99 | 94 | 77
(1024, 1, 1) (1024, 1, 256) | 595100 | 9900 | 146 | 200 | 172 | 164 | 145
(1, 2, 2) (1, 2, 1) | 590 | 31 | 26 | 41 | 76 | 26 | 26
(2, 2, 2) (2, 2, 1) | 1190 | 47 | 26 | 44 | 80 | 26 | 26
(4, 2, 2) (4, 2, 1) | 2357 | 79 | 26 | 41 | 83 | 26 | 26
(8, 2, 2) (8, 2, 1) | 4700 | 140 | 26 | 41 | 82 | 26 | 26
(16, 2, 2) (16, 2, 1) | 9400 | 268 | 26 | 42 | 83 | 26 | 26
(32, 2, 2) (32, 2, 1) | 19000 | 518 | 26 | 44 | 83 | 26 | 26
(64, 2, 2) (64, 2, 1) | 37710 | 1020 | 26 | 42 | 90 | 26 | 26
(128, 2, 2) (128, 2, 1) | 75000 | 2021 | 26 | 42 | 90 | 26 | 26
(512, 2, 2) (512, 2, 1) | 299500 | 8000 | 26 | 50 | 85 | 26 | 26
(1024, 2, 2) (1024, 2, 1) | 599500 | 16030 | 26 | 67 | 99 | 26 | 26
(1, 2, 2) (1, 2, 8) | 600 | 30 | 25 | 26 | 83 | 25 | 25
(2, 2, 2) (2, 2, 8) | 1200 | 46 | 25 | 26 | 85 | 25 | 25
(4, 2, 2) (4, 2, 8) | 2372 | 77 | 25 | 30 | 89 | 27 | 25
(8, 2, 2) (8, 2, 8) | 4740 | 138 | 25 | 28 | 90 | 25 | 25
(16, 2, 2) (16, 2, 8) | 9500 | 274 | 25 | 27 | 90 | 25 | 25
(32, 2, 2) (32, 2, 8) | 19000 | 544 | 25 | 29 | 90 | 25 | 25
(64, 2, 2) (64, 2, 8) | 37680 | 1080 | 25 | 28 | 91 | 25 | 25
(128, 2, 2) (128, 2, 8) | 75500 | 2161 | 25 | 30 | 90 | 25 | 25
(512, 2, 2) (512, 2, 8) | 300700 | 8635 | 25 | 39 | 91 | 25 | 25
(1024, 2, 2) (1024, 2, 8) | 604600 | 17270 | 25 | 52 | 85 | 25 | 25
(1, 2, 2) (1, 2, 64) | 577 | 23 | 25 | 27 | 79 | 25 | 25
(2, 2, 2) (2, 2, 64) | 1157 | 33 | 25 | 28 | 90 | 25 | 25
(4, 2, 2) (4, 2, 64) | 2316 | 53 | 25 | 29 | 90 | 25 | 25
(8, 2, 2) (8, 2, 64) | 4700 | 91 | 25 | 30 | 90 | 25 | 25
(16, 2, 2) (16, 2, 64) | 9290 | 168 | 25 | 30 | 91 | 25 | 25
(32, 2, 2) (32, 2, 64) | 19000 | 322 | 25 | 31 | 90 | 25 | 25
(64, 2, 2) (64, 2, 64) | 37440 | 630 | 25 | 31 | 90 | 25 | 25
(128, 2, 2) (128, 2, 64) | 75000 | 1200 | 26 | 36 | 90 | 25 | 25
(512, 2, 2) (512, 2, 64) | 300000 | 4930 | 33 | 50 | 84 | 33 | 33
(1024, 2, 2) (1024, 2, 64) | 599200 | 9800 | 56 | 70 | 83 | 55 | 56
(1, 2, 2) (1, 2, 256) | 586 | 23 | 25 | 30 | 80 | 28 | 25
(2, 2, 2) (2, 2, 256) | 1170 | 33 | 25 | 31 | 84 | 29 | 25
(4, 2, 2) (4, 2, 256) | 2347 | 53 | 25 | 32 | 83 | 29 | 25
(8, 2, 2) (8, 2, 256) | 4690 | 92 | 25 | 30 | 84 | 30 | 25
(16, 2, 2) (16, 2, 256) | 9400 | 170 | 26 | 31 | 90 | 31 | 25
(32, 2, 2) (32, 2, 256) | 19000 | 322 | 25 | 35 | 83 | 34 | 25
(64, 2, 2) (64, 2, 256) | 37540 | 628 | 25 | 38 | 83 | 36 | 25
(128, 2, 2) (128, 2, 256) | 75000 | 1260 | 37 | 44 | 83 | 43 | 37
(512, 2, 2) (512, 2, 256) | 299500 | 4960 | 100 | 110 | 133 | 103 | 103
(1024, 2, 2) (1024, 2, 256) | 585400 | 10000 | 211 | 210 | 265 | 195 | 209
(1, 8, 8) (1, 8, 1) | 577 | 33 | 26 | 44 | 76 | 26 | 26
(2, 8, 8) (2, 8, 1) | 1200 | 48 | 26 | 45 | 90 | 26 | 26
(4, 8, 8) (4, 8, 1) | 2301 | 80 | 26 | 40 | 84 | 26 | 26
(8, 8, 8) (8, 8, 1) | 4600 | 142 | 26 | 50 | 83 | 26 | 26
(16, 8, 8) (16, 8, 1) | 9100 | 267 | 26 | 50 | 82 | 26 | 26
(32, 8, 8) (32, 8, 1) | 19000 | 519 | 26 | 50 | 83 | 26 | 26
(64, 8, 8) (64, 8, 1) | 36980 | 1020 | 26 | 46 | 90 | 26 | 26
(128, 8, 8) (128, 8, 1) | 74000 | 2053 | 26 | 50 | 85 | 26 | 26
(512, 8, 8) (512, 8, 1) | 295100 | 8100 | 26 | 64 | 82 | 26 | 26
(1024, 8, 8) (1024, 8, 1) | 593200 | 16100 | 26 | 90 | 82 | 26 | 26
(1, 8, 8) (1, 8, 8) | 590 | 32 | 25 | 28 | 77 | 25 | 25
(2, 8, 8) (2, 8, 8) | 1200 | 47 | 25 | 28 | 83 | 25 | 25
(4, 8, 8) (4, 8, 8) | 2322 | 80 | 25 | 29 | 83 | 25 | 25
(8, 8, 8) (8, 8, 8) | 4640 | 145 | 25 | 28 | 83 | 25 | 26
(16, 8, 8) (16, 8, 8) | 9300 | 286 | 25 | 29 | 84 | 26 | 25
(32, 8, 8) (32, 8, 8) | 18500 | 568 | 26 | 32 | 85 | 26 | 25
(64, 8, 8) (64, 8, 8) | 37020 | 1200 | 26 | 30 | 83 | 25 | 25
(128, 8, 8) (128, 8, 8) | 74000 | 2261 | 26 | 31 | 83 | 26 | 25
(512, 8, 8) (512, 8, 8) | 294300 | 9030 | 25 | 41 | 83 | 25 | 25
(1024, 8, 8) (1024, 8, 8) | 592700 | 20000 | 26 | 60 | 82 | 26 | 26
(1, 8, 8) (1, 8, 64) | 564 | 24 | 25 | 27 | 76 | 25 | 25
(2, 8, 8) (2, 8, 64) | 1100 | 34 | 25 | 28 | 82 | 25 | 25
(4, 8, 8) (4, 8, 64) | 2289 | 53 | 25 | 29 | 84 | 25 | 25
(8, 8, 8) (8, 8, 64) | 4500 | 92 | 25 | 30 | 84 | 25 | 25
(16, 8, 8) (16, 8, 64) | 9200 | 170 | 25 | 30 | 83 | 25 | 25
(32, 8, 8) (32, 8, 64) | 18000 | 322 | 25 | 33 | 83 | 25 | 25
(64, 8, 8) (64, 8, 64) | 36870 | 629 | 28 | 32 | 82 | 27 | 28
(128, 8, 8) (128, 8, 64) | 73000 | 1250 | 36 | 38 | 82 | 36 | 35
(512, 8, 8) (512, 8, 64) | 292500 | 5000 | 79 | 88 | 104 | 78 | 79
(1024, 8, 8) (1024, 8, 64) | 582300 | 9900 | 154 | 159 | 196 | 154 | 155
(1, 8, 8) (1, 8, 256) | 582 | 25 | 25 | 28 | 76 | 28 | 25
(2, 8, 8) (2, 8, 256) | 1160 | 35 | 25 | 30 | 83 | 30 | 26
(4, 8, 8) (4, 8, 256) | 2304 | 54 | 31 | 31 | 83 | 30 | 26
(8, 8, 8) (8, 8, 256) | 4600 | 92 | 38 | 34 | 82 | 33 | 41
(16, 8, 8) (16, 8, 256) | 9200 | 169 | 64 | 37 | 82 | 37 | 65
(32, 8, 8) (32, 8, 256) | 18000 | 324 | 74 | 44 | 82 | 42 | 73
(64, 8, 8) (64, 8, 256) | 37090 | 632 | 88 | 61 | 83 | 58 | 88
(128, 8, 8) (128, 8, 256) | 74000 | 1250 | 116 | 91 | 101 | 90 | 115
(512, 8, 8) (512, 8, 256) | 304900 | 5000 | 382 | 314 | 356 | 308 | 383
(1024, 8, 8) (1024, 8, 256) | 609500 | 9800 | 774 | 615 | 672 | 590 | 775
(1, 16, 16) (1, 16, 1) | 596 | 33 | 26 | 46 | 76 | 31 | 25
(2, 16, 16) (2, 16, 1) | 1200 | 48 | 26 | 47 | 83 | 47 | 26
(4, 16, 16) (4, 16, 1) | 2371 | 79 | 26 | 49 | 83 | 26 | 26
(8, 16, 16) (8, 16, 1) | 4726 | 143 | 26 | 50 | 84 | 26 | 26
(16, 16, 16) (16, 16, 1) | 9400 | 270 | 26 | 49 | 83 | 26 | 26
(32, 16, 16) (32, 16, 1) | 18800 | 520 | 26 | 50 | 83 | 26 | 26
(64, 16, 16) (64, 16, 1) | 37280 | 1020 | 26 | 48 | 84 | 26 | 25
(128, 16, 16) (128, 16, 1) | 74000 | 2018 | 26 | 50 | 83 | 26 | 26
(512, 16, 16) (512, 16, 1) | 296300 | 8000 | 26 | 77 | 82 | 26 | 25
(1024, 16, 16) (1024, 16, 1) | 593800 | 16000 | 28 | 110 | 82 | 27 | 28
(1, 16, 16) (1, 16, 8) | 590 | 30 | 25 | 27 | 76 | 30 | 25
(2, 16, 16) (2, 16, 8) | 1200 | 46 | 25 | 29 | 82 | 46 | 25
(4, 16, 16) (4, 16, 8) | 2381 | 82 | 25 | 30 | 82 | 26 | 25
(8, 16, 16) (8, 16, 8) | 4700 | 200 | 25 | 30 | 84 | 26 | 25
(16, 16, 16) (16, 16, 8) | 9400 | 320 | 25 | 30 | 83 | 26 | 25
(32, 16, 16) (32, 16, 8) | 19000 | 640 | 25 | 31 | 83 | 26 | 25
(64, 16, 16) (64, 16, 8) | 37260 | 1234 | 25 | 31 | 91 | 26 | 25
(128, 16, 16) (128, 16, 8) | 74000 | 2542 | 25 | 33 | 90 | 26 | 25
(512, 16, 16) (512, 16, 8) | 299800 | 10000 | 36 | 45 | 82 | 36 | 36
(1024, 16, 16) (1024, 16, 8) | 602300 | 20000 | 56 | 63 | 93 | 56 | 56
(1, 16, 16) (1, 16, 64) | 582 | 24 | 25 | 29 | 76 | 23 | 25
(2, 16, 16) (2, 16, 64) | 1200 | 34 | 25 | 30 | 82 | 32 | 25
(4, 16, 16) (4, 16, 64) | 2318 | 53 | 25 | 30 | 83 | 83 | 25
(8, 16, 16) (8, 16, 64) | 4680 | 92 | 28 | 31 | 83 | 84 | 27
(16, 16, 16) (16, 16, 64) | 9290 | 181 | 42 | 32 | 83 | 84 | 40
(32, 16, 16) (32, 16, 64) | 19000 | 357 | 69 | 35 | 89 | 84 | 70
(64, 16, 16) (64, 16, 64) | 37160 | 714 | 78 | 38 | 82 | 84 | 77
(128, 16, 16) (128, 16, 64) | 74000 | 1420 | 94 | 54 | 81 | 82 | 92
(512, 16, 16) (512, 16, 64) | 299100 | 5836 | 222 | 145 | 235 | 233 | 222
(1024, 16, 16) (1024, 16, 64) | 600700 | 11730 | 572 | 263 | 411 | 410 | 566
(1, 16, 16) (1, 16, 256) | 594 | 23 | 43 | 33 | 76 | 23 | 40
(2, 16, 16) (2, 16, 256) | 1200 | 34 | 48 | 34 | 88 | 33 | 49
(4, 16, 16) (4, 16, 256) | 2378 | 62 | 94 | 35 | 82 | 84 | 91
(8, 16, 16) (8, 16, 256) | 4730 | 118 | 127 | 38 | 82 | 84 | 120
(16, 16, 16) (16, 16, 256) | 9500 | 244 | 142 | 45 | 83 | 84 | 141
(32, 16, 16) (32, 16, 256) | 19000 | 477 | 240 | 59 | 81 | 82 | 239
(64, 16, 16) (64, 16, 256) | 37610 | 969 | 279 | 90 | 118 | 118 | 277
(128, 16, 16) (128, 16, 256) | 76000 | 1960 | 340 | 156 | 218 | 217 | 343
(512, 16, 16) (512, 16, 256) | 302200 | 7910 | 1151 | 542 | 765 | 770 | 1145
(1024, 16, 16) (1024, 16, 256) | 603400 | 15810 | 2404 | 1040 | 1500 | 1490 | 2299
(1, 32, 32) (1, 32, 1) | 591 | 31 | 28 | 53 | 77 | 30 | 28
(2, 32, 32) (2, 32, 1) | 1200 | 46 | 28 | 54 | 83 | 46 | 28
(4, 32, 32) (4, 32, 1) | 2312 | 78 | 31 | 55 | 83 | 31 | 31
(8, 32, 32) (8, 32, 1) | 4600 | 144 | 31 | 56 | 83 | 31 | 31
(16, 32, 32) (16, 32, 1) | 9300 | 287 | 32 | 56 | 82 | 32 | 32
(32, 32, 32) (32, 32, 1) | 18700 | 572 | 35 | 56 | 83 | 35 | 35
(64, 32, 32) (64, 32, 1) | 37620 | 1140 | 36 | 55 | 83 | 36 | 36
(128, 32, 32) (128, 32, 1) | 75000 | 2275 | 43 | 59 | 82 | 43 | 43
(512, 32, 32) (512, 32, 1) | 300700 | 9080 | 83 | 100 | 131 | 83 | 83
(1024, 32, 32) (1024, 32, 1) | 602000 | 19500 | 123 | 190 | 175 | 123 | 123
(1, 32, 32) (1, 32, 8) | 600 | 30 | 25 | 252 | 76 | 29 | 25
(2, 32, 32) (2, 32, 8) | 1200 | 49 | 25 | 253 | 83 | 49 | 25
(4, 32, 32) (4, 32, 8) | 2388 | 94 | 30 | 254 | 83 | 30 | 31
(8, 32, 32) (8, 32, 8) | 4800 | 183 | 33 | 257 | 83 | 32 | 32
(16, 32, 32) (16, 32, 8) | 9600 | 365 | 34 | 258 | 83 | 35 | 35
(32, 32, 32) (32, 32, 8) | 19100 | 727 | 45 | 260 | 83 | 44 | 44
(64, 32, 32) (64, 32, 8) | 38230 | 1453 | 45 | 260 | 83 | 45 | 45
(128, 32, 32) (128, 32, 8) | 76400 | 2901 | 51 | 262 | 83 | 50 | 51
(512, 32, 32) (512, 32, 8) | 306900 | 12000 | 96 | 303 | 142 | 96 | 95
(1024, 32, 32) (1024, 32, 8) | 610900 | 24520 | 162 | 348 | 191 | 162 | 161
(1, 32, 32) (1, 32, 64) | 586 | 23 | 37 | 264 | 76 | 23 | 38
(2, 32, 32) (2, 32, 64) | 1200 | 39 | 44 | 268 | 85 | 41 | 44
(4, 32, 32) (4, 32, 64) | 2367 | 75 | 76 | 270 | 83 | 83 | 77
(8, 32, 32) (8, 32, 64) | 4750 | 145 | 86 | 273 | 83 | 83 | 84
(16, 32, 32) (16, 32, 64) | 9500 | 298 | 124 | 275 | 83 | 83 | 125
(32, 32, 32) (32, 32, 64) | 18900 | 594 | 239 | 272 | 82 | 82 | 238
(64, 32, 32) (64, 32, 64) | 38020 | 1157 | 265 | 276 | 110 | 110 | 262
(128, 32, 32) (128, 32, 64) | 76000 | 2320 | 320 | 293 | 179 | 180 | 306
(512, 32, 32) (512, 32, 64) | 304300 | 9611 | 840 | 533 | 613 | 614 | 835
(1024, 32, 32) (1024, 32, 64) | 605700 | 19190 | 2000 | 910 | 1130 | 1135 | 1900
(1, 32, 32) (1, 32, 256) | 610 | 34 | 87 | 257 | 77 | 31 | 84
(2, 32, 32) (2, 32, 256) | 1300 | 67 | 200 | 265 | 83 | 63 | 179
(4, 32, 32) (4, 32, 256) | 2441 | 126 | 270 | 264 | 84 | 84 | 247
(8, 32, 32) (8, 32, 256) | 4900 | 238 | 274 | 270 | 83 | 83 | 278
(16, 32, 32) (16, 32, 256) | 9600 | 478 | 450 | 281 | 107 | 107 | 433
(32, 32, 32) (32, 32, 256) | 20000 | 972 | 960 | 295 | 172 | 174 | 903
(64, 32, 32) (64, 32, 256) | 38910 | 1915 | 1000 | 352 | 308 | 309 | 993
(128, 32, 32) (128, 32, 256) | 80000 | 3833 | 1200 | 620 | 572 | 572 | 1150
(512, 32, 32) (512, 32, 256) | 313000 | 15410 | 3723 | 2000 | 2173 | 2161 | 3621
(1024, 32, 32) (1024, 32, 256) | 625200 | 30880 | 8000 | 3277 | 4276 | 4279 | 7700
(1, 64, 64) (1, 64, 1) | 609 | 30 | 55 | 68 | 66 | 30 | 30
(2, 64, 64) (2, 64, 1) | 1201 | 53 | 61 | 70 | 85 | 53 | 53
(4, 64, 64) (4, 64, 1) | 2393 | 102 | 70 | 70 | 110 | 65 | 65
(8, 64, 64) (8, 64, 1) | 4824 | 200 | 68 | 71 | 155 | 66 | 66
(16, 64, 64) (16, 64, 1) | 9600 | 401 | 70 | 72 | 84 | 67 | 67
(32, 64, 64) (32, 64, 1) | 19000 | 802 | 76 | 72 | 88 | 73 | 73
(64, 64, 64) (64, 64, 1) | 38890 | 1601 | 76 | 72 | 90 | 76 | 76
(128, 64, 64) (128, 64, 1) | 76000 | 3198 | 100 | 77 | 115 | 96 | 96
(512, 64, 64) (512, 64, 1) | 301700 | 13720 | 210 | 216 | 286 | 209 | 209
(1024, 64, 64) (1024, 64, 1) | 598400 | 28090 | 308 | 355 | 386 | 308 | 308
(1, 64, 64) (1, 64, 8) | 600 | 37 | 48 | 302 | 66 | 37 | 36
(2, 64, 64) (2, 64, 8) | 1190 | 69 | 52 | 310 | 86 | 69 | 69
(4, 64, 64) (4, 64, 8) | 2383 | 135 | 64 | 312 | 112 | 64 | 64
(8, 64, 64) (8, 64, 8) | 4730 | 267 | 67 | 317 | 200 | 66 | 66
(16, 64, 64) (16, 64, 8) | 9500 | 535 | 73 | 318 | 83 | 73 | 73
(32, 64, 64) (32, 64, 8) | 19000 | 1060 | 94 | 323 | 83 | 94 | 94
(64, 64, 64) (64, 64, 8) | 37620 | 2129 | 97 | 329 | 86 | 96 | 97
(128, 64, 64) (128, 64, 8) | 75000 | 4245 | 113 | 343 | 119 | 113 | 113
(512, 64, 64) (512, 64, 8) | 299900 | 18150 | 240 | 444 | 312 | 240 | 240
(1024, 64, 64) (1024, 64, 8) | 602200 | 36350 | 410 | 608 | 465 | 410 | 410
(1, 64, 64) (1, 64, 64) | 580 | 35 | 80 | 314 | 59 | 57 | 34
(2, 64, 64) (2, 64, 64) | 1160 | 64 | 86 | 322 | 71 | 70 | 64
(4, 64, 64) (4, 64, 64) | 2323 | 131 | 162 | 326 | 85 | 85 | 162
(8, 64, 64) (8, 64, 64) | 4640 | 254 | 177 | 331 | 147 | 147 | 177
(16, 64, 64) (16, 64, 64) | 9320 | 505 | 271 | 345 | 108 | 108 | 274
(32, 64, 64) (32, 64, 64) | 18700 | 980 | 532 | 361 | 151 | 151 | 533
(64, 64, 64) (64, 64, 64) | 37290 | 1998 | 595 | 374 | 238 | 238 | 593
(128, 64, 64) (128, 64, 64) | 75000 | 4124 | 717 | 465 | 414 | 415 | 718
(512, 64, 64) (512, 64, 64) | 299900 | 16430 | 2522 | 2114 | 1511 | 1500 | 2739
(1024, 64, 64) (1024, 64, 64) | 596100 | 32950 | 5900 | 4484 | 2859 | 2852 | 5600
(1, 64, 64) (1, 64, 256) | 607 | 59 | 259 | 319 | 59 | 57 | 60
(2, 64, 64) (2, 64, 256) | 1200 | 112 | 380 | 328 | 71 | 70 | 116
(4, 64, 64) (4, 64, 256) | 2429 | 225 | 535 | 332 | 106 | 106 | 535
(8, 64, 64) (8, 64, 256) | 4850 | 444 | 576 | 349 | 189 | 190 | 578
(16, 64, 64) (16, 64, 256) | 9690 | 882 | 960 | 386 | 232 | 232 | 962
(32, 64, 64) (32, 64, 256) | 19000 | 1753 | 2022 | 472 | 399 | 400 | 2024
(64, 64, 64) (64, 64, 256) | 38960 | 3546 | 2249 | 1126 | 742 | 741 | 2250
(128, 64, 64) (128, 64, 256) | 78000 | 7100 | 2730 | 2669 | 1416 | 1410 | 2734
(512, 64, 64) (512, 64, 256) | 308900 | 28150 | 12000 | 11000 | 5600 | 5568 | 12000
(1024, 64, 64) (1024, 64, 256) | 615400 | 56470 | 25350 | 22430 | 11000 | 11000 | 25100
(1, 128, 128) (1, 128, 1) | 584 | 42 | 115 | 100 | 67 | 42 | 43
(2, 128, 128) (2, 128, 1) | 1200 | 83 | 122 | 100 | 86 | 83 | 83
(4, 128, 128) (4, 128, 1) | 2312 | 162 | 136 | 100 | 113 | 137 | 136
(8, 128, 128) (8, 128, 1) | 4620 | 324 | 138 | 100 | 197 | 138 | 138
(16, 128, 128) (16, 128, 1) | 9200 | 642 | 143 | 100 | 159 | 144 | 144
(32, 128, 128) (32, 128, 1) | 18400 | 1291 | 153 | 103 | 163 | 153 | 153
(64, 128, 128) (64, 128, 1) | 36770 | 2711 | 178 | 120 | 191 | 178 | 178
(128, 128, 128) (128, 128, 1) | 74000 | 5447 | 231 | 150 | 251 | 231 | 231
(512, 128, 128) (512, 128, 1) | 292800 | 21820 | 505 | 406 | 639 | 505 | 505
(1024, 128, 128) (1024, 128, 1) | 585700 | 43630 | 769 | 685 | 911 | 769 | 769
(1, 128, 128) (1, 128, 8) | 590 | 58 | 102 | 416 | 70 | 65 | 57
(2, 128, 128) (2, 128, 8) | 1170 | 111 | 109 | 436 | 98 | 98 | 111
(4, 128, 128) (4, 128, 8) | 2329 | 221 | 133 | 446 | 170 | 168 | 133
(8, 128, 128) (8, 128, 8) | 4620 | 437 | 139 | 453 | 308 | 306 | 138
(16, 128, 128) (16, 128, 8) | 9260 | 863 | 154 | 460 | 144 | 145 | 154
(32, 128, 128) (32, 128, 8) | 19000 | 1723 | 202 | 480 | 158 | 155 | 202
(64, 128, 128) (64, 128, 8) | 38360 | 3697 | 223 | 493 | 184 | 185 | 224
(128, 128, 128) (128, 128, 8) | 100000 | 7435 | 267 | 532 | 262 | 263 | 268
(512, 128, 128) (512, 128, 8) | 291300 | 29760 | 595 | 827 | 745 | 744 | 596
(1024, 128, 128) (1024, 128, 8) | 572900 | 59560 | 1290 | 1372 | 1132 | 1133 | 1308
(1, 128, 128) (1, 128, 64) | 567 | 61 | 164 | 447 | 59 | 60 | 63
(2, 128, 128) (2, 128, 64) | 1100 | 126 | 206 | 467 | 96 | 98 | 124
(4, 128, 128) (4, 128, 64) | 2848 | 252 | 350 | 490 | 167 | 168 | 352
(8, 128, 128) (8, 128, 64) | 4550 | 482 | 377 | 498 | 307 | 309 | 376
(16, 128, 128) (16, 128, 64) | 9300 | 972 | 578 | 522 | 215 | 217 | 580
(32, 128, 128) (32, 128, 64) | 19000 | 2022 | 1120 | 561 | 312 | 313 | 1120
(64, 128, 128) (64, 128, 64) | 37780 | 4128 | 1260 | 615 | 500 | 501 | 1260
(128, 128, 128) (128, 128, 64) | 75400 | 8234 | 1553 | 1111 | 876 | 876 | 1550
(512, 128, 128) (512, 128, 64) | 303200 | 32920 | 7600 | 6050 | 3345 | 3366 | 7840
(1024, 128, 128) (1024, 128, 64) | 604500 | 65880 | 16500 | 13000 | 6410 | 6400 | 16400
(1, 128, 128) (1, 128, 256) | 648 | 117 | 632 | 460 | 74 | 74 | 116
(2, 128, 128) (2, 128, 256) | 1300 | 231 | 825 | 480 | 126 | 126 | 227
(4, 128, 128) (4, 128, 256) | 2586 | 456 | 1100 | 497 | 227 | 226 | 1100
(8, 128, 128) (8, 128, 256) | 5170 | 920 | 1190 | 542 | 426 | 426 | 1190
(16, 128, 128) (16, 128, 256) | 10400 | 1841 | 2024 | 688 | 481 | 482 | 2035
(32, 128, 128) (32, 128, 256) | 20680 | 3702 | 4243 | 1220 | 836 | 836 | 4252
(64, 128, 128) (64, 128, 256) | 41350 | 7421 | 4803 | 3260 | 1600 | 1600 | 4810
(128, 128, 128) (128, 128, 256) | 83000 | 14850 | 6470 | 7600 | 3075 | 3072 | 6470
(512, 128, 128) (512, 128, 256) | 331800 | 59180 | 30510 | 30480 | 12000 …
```

</details>

When linalg_lu_solve was added in #72935, I made the big mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be false by a large margin. This PR fixes that and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and was crafted by hand but, as the results show, it is effective.
xwang233 the cusolver team should look into this one, as I was able to outperform both the cublas and cusolver algorithms by using triangular solves...
I'll post the benchmarks in a second.
ghstack-source-id: a1d542f
Pull Request resolved: #79838
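To make the "triangular solves" strategy concrete, here is a minimal, dependency-free sketch in plain Python. It is an illustration only, not the code from the PR: the helper names `lu_factor_sketch` and `lu_solve_sketch` are hypothetical, and exact `Fraction` arithmetic is used purely so the result is easy to check. In PyTorch the corresponding steps would be `torch.lu_unpack` followed by two calls to `torch.linalg.solve_triangular`. Note that both substitution loops run once per column of B, which is one way to see why the number of columns of B can matter for the choice of backend.

```python
from fractions import Fraction


def lu_factor_sketch(a):
    """LU factorization with partial pivoting (hypothetical helper).

    Returns (lu, perm): L (unit lower triangular) and U packed into one
    matrix, plus a row permutation such that row i of P @ A is a[perm[i]].
    """
    n = len(a)
    lu = [[Fraction(x) for x in row] for row in a]
    perm = list(range(n))
    for k in range(n):
        # Choose the row with the largest entry in column k as the pivot row.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        if lu[p][k] == 0:
            raise ValueError("matrix is singular")
        lu[k], lu[p] = lu[p], lu[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]  # multiplier, stored in the L part
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm


def lu_solve_sketch(lu, perm, b):
    """Solve A X = B for an n x k matrix B via two triangular solves."""
    n, k = len(lu), len(b[0])
    # Apply the permutation to the right-hand side: X starts as P @ B.
    x = [[Fraction(v) for v in b[perm[i]]] for i in range(n)]
    # Forward substitution with the unit lower triangular factor L.
    for i in range(n):
        for j in range(i):
            for c in range(k):
                x[i][c] -= lu[i][j] * x[j][c]
    # Back substitution with the upper triangular factor U.
    for i in reversed(range(n)):
        for j in range(i + 1, n):
            for c in range(k):
                x[i][c] -= lu[i][j] * x[j][c]
        for c in range(k):
            x[i][c] /= lu[i][i]
    return x


A = [[0, 2], [1, 1]]
B = [[2, 4], [3, 5]]  # two right-hand sides, one per column
lu, perm = lu_factor_sketch(A)
X = lu_solve_sketch(lu, perm, B)
```

The factorization is computed once and can be reused for any number of right-hand sides; only the two substitution passes depend on the width of B.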
When linalg_lu_solve was added in https://github.com/pytorch/pytorch/pull/72935, I made the big mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be false by a large margin. This PR fixes that and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and was crafted by hand but, as the results show, it is effective.

xwang233 the cusolver team should look into this one, as I was able to outperform both the cublas and cusolver algorithms by using triangular solves...

**Edit:** These new heuristics yield a **3x-4x performance improvement** over the previous ones when the **rhs is a square matrix or a batch of square matrices** (which is the case when computing inverses, gradients for the determinant, etc.).

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[------------------------------ linalg.lu_solve CUDA ------------------------------]
 | lu_solve looped_magma | lu_solve looped_cusolver | lu_solve batched_cublas | lu_solve batched_magma | lu_solve unpack+solve_triangular | lu_solve heuristic | previous_heuristic
1 threads: --------------------------------------------------------------------------
(1, 1, 1) (1, 1, 1) | 591 | 31 | 25 | 50 | 74 | 27 | 26
(2, 1, 1) (2, 1, 1) | 1200 | 47 | 25 | 41 | 81 | 25 | 26
(4, 1, 1) (4, 1, 1) | 2340 | 79 | 25 | 42 | 81 | 26 | 26
(8, 1, 1) (8, 1, 1) | 4680 | 140 | 25 | 43 | 84 | 25 | 26
(16, 1, 1) (16, 1, 1) | 9400 | 268 | 25 | 45 | 82 | 25 | 26
(32, 1, 1) (32, 1, 1) | 18700 | 519 | 25 | 41 | 81 | 26 | 26
(64, 1, 1) (64, 1, 1) | 37240 | 1020 | 25 | 39 | 83 | 26 | 26
(128, 1, 1) (128, 1, 1) | 75000 | 2021 | 26 | 41 | 81 | 26 | 26
(512, 1, 1) (512, 1, 1) | 299400 | 8000 | 25 | 50 | 83 | 26 | 25
(1024, 1, 1) (1024, 1, 1) | 598600 | 16100 | 26 | 55 | 82 | 26 | 26
(1, 1, 1) (1, 1, 8) | 596 | 31 | 25 | 24 | 75 | 25 | 26
(2, 1, 1) (2, 1, 8) | 1190 | 47 | 25 | 25 | 83 | 25 | 26
(4, 1, 1) (4, 1, 8) | 2367 | 78 | 26 | 25 | 83 | 25 | 26
(8, 1, 1) (8, 1, 8) | 4800 | 140 | 25 | 25 | 82 | 25 | 25
(16, 1, 1) (16, 1, 8) | 9500 | 269 | 26 | 26 | 82 | 25 | 25
(32, 1, 1) (32, 1, 8) | 18900 | 536 | 26 | 26 | 82 | 26 | 26
(64, 1, 1) (64, 1, 8) | 37710 | 1066 | 26 | 26 | 84 | 26 | 25
(128, 1, 1) (128, 1, 8) | 75400 | 2131 | 26 | 27 | 82 | 26 | 26
(512, 1, 1) (512, 1, 8) | 301900 | 8520 | 26 | 33 | 80 | 26 | 26
(1024, 1, 1) (1024, 1, 8) | 603600 | 17030 | 26 | 41 | 82 | 26 | 26
(1, 1, 1) (1, 1, 64) | 577 | 23 | 25 | 26 | 75 | 26 | 26
(2, 1, 1) (2, 1, 64) | 1160 | 34 | 26 | 27 | 83 | 26 | 26
(4, 1, 1) (4, 1, 64) | 2334 | 53 | 25 | 27 | 84 | 25 | 26
(8, 1, 1) (8, 1, 64) | 4650 | 92 | 26 | 28 | 82 | 25 | 26
(16, 1, 1) (16, 1, 64) | 9290 | 170 | 25 | 29 | 82 | 26 | 26
(32, 1, 1) (32, 1, 64) | 18600 | 326 | 25 | 29 | 84 | 26 | 26
(64, 1, 1) (64, 1, 64) | 37160 | 634 | 26 | 32 | 83 | 26 | 26
(128, 1, 1) (128, 1, 64) | 74000 | 1260 | 26 | 31 | 82 | 26 | 26
(512, 1, 1) (512, 1, 64) | 297300 | 4960 | 26 | 42 | 82 | 26 | 26
(1024, 1, 1) (1024, 1, 64) | 595400 | 9900 | 43 | 57 | 81 | 43 | 43
(1, 1, 1) (1, 1, 256) | 583 | 24 | 26 | 28 | 75 | 35 | 26
(2, 1, 1) (2, 1, 256) | 1170 | 34 | 25 | 28 | 82 | 27 | 26
(4, 1, 1) (4, 1, 256) | 2328 | 53 | 25 | 28 | 84 | 28 | 26
(8, 1, 1) (8, 1, 256) | 4660 | 93 | 25 | 29 | 83 | 29 | 26
(16, 1, 1) (16, 1, 256) | 9300 | 170 | 25 | 30 | 82 | 30 | 26
(32, 1, 1) (32, 1, 256) | 19000 | 326 | 25 | 35 | 82 | 32 | 26
(64, 1, 1) (64, 1, 256) | 37200 | 638 | 26 | 37 | 82 | 35 | 26
(128, 1, 1) (128, 1, 256) | 74000 | 1300 | 26 | 43 | 82 | 40 | 26
(512, 1, 1) (512, 1, 256) | 297900 | 5000 | 78 | 94 | 99 | 94 | 77
(1024, 1, 1) (1024, 1, 256) | 595100 | 9900 | 146 | 200 | 172 | 164 | 145
(1, 2, 2) (1, 2, 1) | 590 | 31 | 26 | 41 | 76 | 26 | 26
(2, 2, 2) (2, 2, 1) | 1190 | 47 | 26 | 44 | 80 | 26 | 26
(4, 2, 2) (4, 2, 1) | 2357 | 79 | 26 | 41 | 83 | 26 | 26
(8, 2, 2) (8, 2, 1) | 4700 | 140 | 26 | 41 | 82 | 26 | 26
(16, 2, 2) (16, 2, 1) | 9400 | 268 | 26 | 42 | 83 | 26 | 26
(32, 2, 2) (32, 2, 1) | 19000 | 518 | 26 | 44 | 83 | 26 | 26
(64, 2, 2) (64, 2, 1) | 37710 | 1020 | 26 | 42 | 90 | 26 | 26
(128, 2, 2) (128, 2, 1) | 75000 | 2021 | 26 | 42 | 90 | 26 | 26
(512, 2, 2) (512, 2, 1) | 299500 | 8000 | 26 | 50 | 85 | 26 | 26
(1024, 2, 2) (1024, 2, 1) | 599500 | 16030 | 26 | 67 | 99 | 26 | 26
(1, 2, 2) (1, 2, 8) | 600 | 30 | 25 | 26 | 83 | 25 | 25
(2, 2, 2) (2, 2, 8) | 1200 | 46 | 25 | 26 | 85 | 25 | 25
(4, 2, 2) (4, 2, 8) | 2372 | 77 | 25 | 30 | 89 | 27 | 25
(8, 2, 2) (8, 2, 8) | 4740 | 138 | 25 | 28 | 90 | 25 | 25
(16, 2, 2) (16, 2, 8) | 9500 | 274 | 25 | 27 | 90 | 25 | 25
(32, 2, 2) (32, 2, 8) | 19000 | 544 | 25 | 29 | 90 | 25 | 25
(64, 2, 2) (64, 2, 8) | 37680 | 1080 | 25 | 28 | 91 | 25 | 25
(128, 2, 2) (128, 2, 8) | 75500 | 2161 | 25 | 30 | 90 | 25 | 25
(512, 2, 2) (512, 2, 8) | 300700 | 8635 | 25 | 39 | 91 | 25 | 25
(1024, 2, 2) (1024, 2, 8) | 604600 | 17270 | 25 | 52 | 85 | 25 | 25
(1, 2, 2) (1, 2, 64) | 577 | 23 | 25 | 27 | 79 | 25 | 25
(2, 2, 2) (2, 2, 64) | 1157 | 33 | 25 | 28 | 90 | 25 | 25
(4, 2, 2) (4, 2, 64) | 2316 | 53 | 25 | 29 | 90 | 25 | 25
(8, 2, 2) (8, 2, 64) | 4700 | 91 | 25 | 30 | 90 | 25 | 25
(16, 2, 2) (16, 2, 64) | 9290 | 168 | 25 | 30 | 91 | 25 | 25
(32, 2, 2) (32, 2, 64) | 19000 | 322 | 25 | 31 | 90 | 25 | 25
(64, 2, 2) (64, 2, 64) | 37440 | 630 | 25 | 31 | 90 | 25 | 25
(128, 2, 2) (128, 2, 64) | 75000 | 1200 | 26 | 36 | 90 | 25 | 25
(512, 2, 2) (512, 2, 64) | 300000 | 4930 | 33 | 50 | 84 | 33 | 33
(1024, 2, 2) (1024, 2, 64) | 599200 | 9800 | 56 | 70 | 83 | 55 | 56
(1, 2, 2) (1, 2, 256) | 586 | 23 | 25 | 30 | 80 | 28 | 25
(2, 2, 2) (2, 2, 256) | 1170 | 33 | 25 | 31 | 84 | 29 | 25
(4, 2, 2) (4, 2, 256) | 2347 | 53 | 25 | 32 | 83 | 29 | 25
(8, 2, 2) (8, 2, 256) | 4690 | 92 | 25 | 30 | 84 | 30 | 25
(16, 2, 2) (16, 2, 256) | 9400 | 170 | 26 | 31 | 90 | 31 | 25
(32, 2, 2) (32, 2, 256) | 19000 | 322 | 25 | 35 | 83 | 34 | 25
(64, 2, 2) (64, 2, 256) | 37540 | 628 | 25 | 38 | 83 | 36 | 25
(128, 2, 2) (128, 2, 256) | 75000 | 1260 | 37 | 44 | 83 | 43 | 37
(512, 2, 2) (512, 2, 256) | 299500 | 4960 | 100 | 110 | 133 | 103 | 103
(1024, 2, 2) (1024, 2, 256) | 585400 | 10000 | 211 | 210 | 265 | 195 | 209
(1, 8, 8) (1, 8, 1) | 577 | 33 | 26 | 44 | 76 | 26 | 26
(2, 8, 8) (2, 8, 1) | 1200 | 48 | 26 | 45 | 90 | 26 | 26
(4, 8, 8) (4, 8, 1) | 2301 | 80 | 26 | 40 | 84 | 26 | 26
(8, 8, 8) (8, 8, 1) | 4600 | 142 | 26 | 50 | 83 | 26 | 26
(16, 8, 8) (16, 8, 1) | 9100 | 267 | 26 | 50 | 82 | 26 | 26
(32, 8, 8) (32, 8, 1) | 19000 | 519 | 26 | 50 | 83 | 26 | 26
(64, 8, 8) (64, 8, 1) | 36980 | 1020 | 26 | 46 | 90 | 26 | 26
(128, 8, 8) (128, 8, 1) | 74000 | 2053 | 26 | 50 | 85 | 26 | 26
(512, 8, 8) (512, 8, 1) | 295100 | 8100 | 26 | 64 | 82 | 26 | 26
(1024, 8, 8) (1024, 8, 1) | 593200 | 16100 | 26 | 90 | 82 | 26 | 26
(1, 8, 8) (1, 8, 8) | 590 | 32 | 25 | 28 | 77 | 25 | 25
(2, 8, 8) (2, 8, 8) | 1200 | 47 | 25 | 28 | 83 | 25 | 25
(4, 8, 8) (4, 8, 8) | 2322 | 80 | 25 | 29 | 83 | 25 | 25
(8, 8, 8) (8, 8, 8) | 4640 | 145 | 25 | 28 | 83 | 25 | 26
(16, 8, 8) (16, 8, 8) | 9300 | 286 | 25 | 29 | 84 | 26 | 25
(32, 8, 8) (32, 8, 8) | 18500 | 568 | 26 | 32 | 85 | 26 | 25
(64, 8, 8) (64, 8, 8) | 37020 | 1200 | 26 | 30 | 83 | 25 | 25
(128, 8, 8) (128, 8, 8) | 74000 | 2261 | 26 | 31 | 83 | 26 | 25
(512, 8, 8) (512, 8, 8) | 294300 | 9030 | 25 | 41 | 83 | 25 | 25
(1024, 8, 8) (1024, 8, 8) | 592700 | 20000 | 26 | 60 | 82 | 26 | 26
(1, 8, 8) (1, 8, 64) | 564 | 24 | 25 | 27 | 76 | 25 | 25
(2, 8, 8) (2, 8, 64) | 1100 | 34 | 25 | 28 | 82 | 25 | 25
(4, 8, 8) (4, 8, 64) | 2289 | 53 | 25 | 29 | 84 | 25 | 25
(8, 8, 8) (8, 8, 64) | 4500 | 92 | 25 | 30 | 84 | 25 | 25
(16, 8, 8) (16, 8, 64) | 9200 | 170 | 25 | 30 | 83 | 25 | 25
(32, 8, 8) (32, 8, 64) | 18000 | 322 | 25 | 33 | 83 | 25 | 25
(64, 8, 8) (64, 8, 64) | 36870 | 629 | 28 | 32 | 82 | 27 | 28
(128, 8, 8) (128, 8, 64) | 73000 | 1250 | 36 | 38 | 82 | 36 | 35
(512, 8, 8) (512, 8, 64) | 292500 | 5000 | 79 | 88 | 104 | 78 | 79
(1024, 8, 8) (1024, 8, 64) | 582300 | 9900 | 154 | 159 | 196 | 154 | 155
(1, 8, 8) (1, 8, 256) | 582 | 25 | 25 | 28 | 76 | 28 | 25
(2, 8, 8) (2, 8, 256) | 1160 | 35 | 25 | 30 | 83 | 30 | 26
(4, 8, 8) (4, 8, 256) | 2304 | 54 | 31 | 31 | 83 | 30 | 26
(8, 8, 8) (8, 8, 256) | 4600 | 92 | 38 | 34 | 82 | 33 | 41
(16, 8, 8) (16, 8, 256) | 9200 | 169 | 64 | 37 | 82 | 37 | 65
(32, 8, 8) (32, 8, 256) | 18000 | 324 | 74 | 44 | 82 | 42 | 73
(64, 8, 8) (64, 8, 256) | 37090 | 632 | 88 | 61 | 83 | 58 | 88
(128, 8, 8) (128, 8, 256) | 74000 | 1250 | 116 | 91 | 101 | 90 | 115
(512, 8, 8) (512, 8, 256) | 304900 | 5000 | 382 | 314 | 356 | 308 | 383
(1024, 8, 8) (1024, 8, 256) | 609500 | 9800 | 774 | 615 | 672 | 590 | 775
(1, 16, 16) (1, 16, 1) | 596 | 33 | 26 | 46 | 76 | 31 | 25
(2, 16, 16) (2, 16, 1) | 1200 | 48 | 26 | 47 | 83 | 47 | 26
(4, 16, 16) (4, 16, 1) | 2371 | 79 | 26 | 49 | 83 | 26 | 26
(8, 16, 16) (8, 16, 1) | 4726 | 143 | 26 | 50 | 84 | 26 | 26
(16, 16, 16) (16, 16, 1) | 9400 | 270 | 26 | 49 | 83 | 26 | 26
(32, 16, 16) (32, 16, 1) | 18800 | 520 | 26 | 50 | 83 | 26 | 26
(64, 16, 16) (64, 16, 1) | 37280 | 1020 | 26 | 48 | 84 | 26 | 25
(128, 16, 16) (128, 16, 1) | 74000 | 2018 | 26 | 50 | 83 | 26 | 26
(512, 16, 16) (512, 16, 1) | 296300 | 8000 | 26 | 77 | 82 | 26 | 25
(1024, 16, 16) (1024, 16, 1) | 593800 | 16000 | 28 | 110 | 82 | 27 | 28
(1, 16, 16) (1, 16, 8) | 590 | 30 | 25 | 27 | 76 | 30 | 25
(2, 16, 16) (2, 16, 8) | 1200 | 46 | 25 | 29 | 82 | 46 | 25
(4, 16, 16) (4, 16, 8) | 2381 | 82 | 25 | 30 | 82 | 26 | 25
(8, 16, 16) (8, 16, 8) | 4700 | 200 | 25 | 30 | 84 | 26 | 25
(16, 16, 16) (16, 16, 8) | 9400 | 320 | 25 | 30 | 83 | 26 | 25
(32, 16, 16) (32, 16, 8) | 19000 | 640 | 25 | 31 | 83 | 26 | 25
(64, 16, 16) (64, 16, 8) | 37260 | 1234 | 25 | 31 | 91 | 26 | 25
(128, 16, 16) (128, 16, 8) | 74000 | 2542 | 25 | 33 | 90 | 26 | 25
(512, 16, 16) (512, 16, 8) | 299800 | 10000 | 36 | 45 | 82 | 36 | 36
(1024, 16, 16) (1024, 16, 8) | 602300 | 20000 | 56 | 63 | 93 | 56 | 56
(1, 16, 16) (1, 16, 64) | 582 | 24 | 25 | 29 | 76 | 23 | 25
(2, 16, 16) (2, 16, 64) | 1200 | 34 | 25 | 30 | 82 | 32 | 25
(4, 16, 16) (4, 16, 64) | 2318 | 53 | 25 | 30 | 83 | 83 | 25
(8, 16, 16) (8, 16, 64) | 4680 | 92 | 28 | 31 | 83 | 84 | 27
(16, 16, 16) (16, 16, 64) | 9290 | 181 | 42 | 32 | 83 | 84 | 40
(32, 16, 16) (32, 16, 64) | 19000 | 357 | 69 | 35 | 89 | 84 | 70
(64, 16, 16) (64, 16, 64) | 37160 | 714 | 78 | 38 | 82 | 84 | 77
(128, 16, 16) (128, 16, 64) | 74000 | 1420 | 94 | 54 | 81 | 82 | 92
(512, 16, 16) (512, 16, 64) | 299100 | 5836 | 222 | 145 | 235 | 233 | 222
(1024, 16, 16) (1024, 16, 64) | 600700 | 11730 | 572 | 263 | 411 | 410 | 566
(1, 16, 16) (1, 16, 256) | 594 | 23 | 43 | 33 | 76 | 23 | 40
(2, 16, 16) (2, 16, 256) | 1200 | 34 | 48 | 34 | 88 | 33 | 49
(4, 16, 16) (4, 16, 256) | 2378 | 62 | 94 | 35 | 82 | 84 | 91
(8, 16, 16) (8, 16, 256) | 4730 | 118 | 127 | 38 | 82 | 84 | 120
(16, 16, 16) (16, 16, 256) | 9500 | 244 | 142 | 45 | 83 | 84 | 141
(32, 16, 16) (32, 16, 256) | 19000 | 477 | 240 | 59 | 81 | 82 | 239
(64, 16, 16) (64, 16, 256) | 37610 | 969 | 279 | 90 | 118 | 118 | 277
(128, 16, 16) (128, 16, 256) | 76000 | 1960 | 340 | 156 | 218 | 217 | 343
(512, 16, 16) (512, 16, 256) | 302200 | 7910 | 1151 | 542 | 765 | 770 | 1145
(1024, 16, 16) (1024, 16, 256) | 603400 | 15810 | 2404 | 1040 | 1500 | 1490 | 2299
(1, 32, 32) (1, 32, 1) | 591 | 31 | 28 | 53 | 77 | 30 | 28
(2, 32, 32) (2, 32, 1) | 1200 | 46 | 28 | 54 | 83 | 46 | 28
(4, 32, 32) (4, 32, 1) | 2312 | 78 | 31 | 55 | 83 | 31 | 31
(8, 32, 32) (8, 32, 1) | 4600 | 144 | 31 | 56 | 83 | 31 | 31
(16, 32, 32) (16, 32, 1) | 9300 | 287 | 32 | 56 | 82 | 32 | 32
(32, 32, 32) (32, 32, 1) | 18700 | 572 | 35 | 56 | 83 | 35 | 35
(64, 32, 32) (64, 32, 1) | 37620 | 1140 | 36 | 55 | 83 | 36 | 36
(128, 32, 32) (128, 32, 1) | 75000 | 2275 | 43 | 59 | 82 | 43 | 43
(512, 32, 32) (512, 32, 1) | 300700 | 9080 | 83 | 100 | 131 | 83 | 83
(1024, 32, 32) (1024, 32, 1) | 602000 | 19500 | 123 | 190 | 175 | 123 | 123
(1, 32, 32) (1, 32, 8) | 600 | 30 | 25 | 252 | 76 | 29 | 25
(2, 32, 32) (2, 32, 8) | 1200 | 49 | 25 | 253 | 83 | 49 | 25
(4, 32, 32) (4, 32, 8) | 2388 | 94 | 30 | 254 | 83 | 30 | 31
(8, 32, 32) (8, 32, 8) | 4800 | 183 | 33 | 257 | 83 | 32 | 32
(16, 32, 32) (16, 32, 8) | 9600 | 365 | 34 | 258 | 83 | 35 | 35
(32, 32, 32) (32, 32, 8) | 19100 | 727 | 45 | 260 | 83 | 44 | 44
(64, 32, 32) (64, 32, 8) | 38230 | 1453 | 45 | 260 | 83 | 45 | 45
(128, 32, 32) (128, 32, 8) | 76400 | 2901 | 51 | 262 | 83 | 50 | 51
(512, 32, 32) (512, 32, 8) | 306900 | 12000 | 96 | 303 | 142 | 96 | 95
(1024, 32, 32) (1024, 32, 8) | 610900 | 24520 | 162 | 348 | 191 | 162 | 161
(1, 32, 32) (1, 32, 64) | 586 | 23 | 37 | 264 | 76 | 23 | 38
(2, 32, 32) (2, 32, 64) | 1200 | 39 | 44 | 268 | 85 | 41 | 44
(4, 32, 32) (4, 32, 64) | 2367 | 75 | 76 | 270 | 83 | 83 | 77
(8, 32, 32) (8, 32, 64) | 4750 | 145 | 86 | 273 | 83 | 83 | 84
(16, 32, 32) (16, 32, 64) | 9500 | 298 | 124 | 275 | 83 | 83 | 125
(32, 32, 32) (32, 32, 64) | 18900 | 594 | 239 | 272 | 82 | 82 | 238
(64, 32, 32) (64, 32, 64) | 38020 | 1157 | 265 | 276 | 110 | 110 | 262
(128, 32, 32) (128, 32, 64) | 76000 | 2320 | 320 | 293 | 179 | 180 | 306
(512, 32, 32) (512, 32, 64) | 304300 | 9611 | 840 | 533 | 613 | 614 | 835
(1024, 32, 32) (1024, 32, 64) | 605700 | 19190 | 2000 | 910 | 1130 | 1135 | 1900
(1, 32, 32) (1, 32, 256) | 610 | 34 | 87 | 257 | 77 | 31 | 84
(2, 32, 32) (2, 32, 256) | 1300 | 67 | 200 | 265 | 83 | 63 | 179
(4, 32, 32) (4, 32, 256) | 2441 | 126 | 270 | 264 | 84 | 84 | 247
(8, 32, 32) (8, 32, 256) | 4900 | 238 | 274 | 270 | 83 | 83 | 278
(16, 32, 32) (16, 32, 256) | 9600 | 478 | 450 | 281 | 107 | 107 | 433
(32, 32, 32) (32, 32, 256) | 20000 | 972 | 960 | 295 | 172 | 174 | 903
(64, 32, 32) (64, 32, 256) | 38910 | 1915 | 1000 | 352 | 308 | 309 | 993
(128, 32, 32) (128, 32, 256) | 80000 | 3833 | 1200 | 620 | 572 | 572 | 1150
(512, 32, 32) (512, 32, 256) | 313000 | 15410 | 3723 | 2000 | 2173 | 2161 | 3621
(1024, 32, 32) (1024, 32, 256) | 625200 | 30880 | 8000 | 3277 | 4276 | 4279 | 7700
(1, 64, 64) (1, 64, 1) | 609 | 30 | 55 | 68 | 66 | 30 | 30
(2, 64, 64) (2, 64, 1) | 1201 | 53 | 61 | 70 | 85 | 53 | 53
(4, 64, 64) (4, 64, 1) | 2393 | 102 | 70 | 70 | 110 | 65 | 65
(8, 64, 64) (8, 64, 1) | 4824 | 200 | 68 | 71 | 155 | 66 | 66
(16, 64, 64) (16, 64, 1) | 9600 | 401 | 70 | 72 | 84 | 67 | 67
(32, 64, 64) (32, 64, 1) | 19000 | 802 | 76 | 72 | 88 | 73 | 73
(64, 64, 64) (64, 64, 1) | 38890 | 1601 | 76 | 72 | 90 | 76 | 76
(128, 64, 64) (128, 64, 1) | 76000 | 3198 | 100 | 77 | 115 | 96 | 96
(512, 64, 64) (512, 64, 1) | 301700 | 13720 | 210 | 216 | 286 | 209 | 209
(1024, 64, 64) (1024, 64, 1) | 598400 | 28090 | 308 | 355 | 386 | 308 | 308
(1, 64, 64) (1, 64, 8) | 600 | 37 | 48 | 302 | 66 | 37 | 36
(2, 64, 64) (2, 64, 8) | 1190 | 69 | 52 | 310 | 86 | 69 | 69
(4, 64, 64) (4, 64, 8) | 2383 | 135 | 64 | 312 | 112 | 64 | 64
(8, 64, 64) (8, 64, 8) | 4730 | 267 | 67 | 317 | 200 | 66 | 66
(16, 64, 64) (16, 64, 8) | 9500 | 535 | 73 | 318 | 83 | 73 | 73
(32, 64, 64) (32, 64, 8) | 19000 | 1060 | 94 | 323 | 83 | 94 | 94
(64, 64, 64) (64, 64, 8) | 37620 | 2129 | 97 | 329 | 86 | 96 | 97
(128, 64, 64) (128, 64, 8) | 75000 | 4245 | 113 | 343 | 119 | 113 | 113
(512, 64, 64) (512, 64, 8) | 299900 | 18150 | 240 | 444 | 312 | 240 | 240
(1024, 64, 64) (1024, 64, 8) | 602200 | 36350 | 410 | 608 | 465 | 410 | 410
(1, 64, 64) (1, 64, 64) | 580 | 35 | 80 | 314 | 59 | 57 | 34
(2, 64, 64) (2, 64, 64) | 1160 | 64 | 86 | 322 | 71 | 70 | 64
(4, 64, 64) (4, 64, 64) | 2323 | 131 | 162 | 326 | 85 | 85 | 162
(8, 64, 64) (8, 64, 64) | 4640 | 254 | 177 | 331 | 147 | 147 | 177
(16, 64, 64) (16, 64, 64) | 9320 | 505 | 271 | 345 | 108 | 108 | 274
(32, 64, 64) (32, 64, 64) | 18700 | 980 | 532 | 361 | 151 | 151 | 533
(64, 64, 64) (64, 64, 64) | 37290 | 1998 | 595 | 374 | 238 | 238 | 593
(128, 64, 64) (128, 64, 64) | 75000 | 4124 | 717 | 465 | 414 | 415 | 718
(512, 64, 64) (512, 64, 64) | 299900 | 16430 | 2522 | 2114 | 1511 | 1500 | 2739
(1024, 64, 64) (1024, 64, 64) | 596100 | 32950 | 5900 | 4484 | 2859 | 2852 | 5600
(1, 64, 64) (1, 64, 256) | 607 | 59 | 259 | 319 | 59 | 57 | 60
(2, 64, 64) (2, 64, 256) | 1200 | 112 | 380 | 328 | 71 | 70 | 116
(4, 64, 64) (4, 64, 256) | 2429 | 225 | 535 | 332 | 106 | 106 | 535
(8, 64, 64) (8, 64, 256) | 4850 | 444 | 576 | 349 | 189 | 190 | 578
(16, 64, 64) (16, 64, 256) | 9690 | 882 | 960 | 386 | 232 | 232 | 962
(32, 64, 64) (32, 64, 256) | 19000 | 1753 | 2022 | 472 | 399 | 400 | 2024
(64, 64, 64) (64, 64, 256) | 38960 | 3546 | 2249 | 1126 | 742 | 741 | 2250
(128, 64, 64) (128, 64, 256) | 78000 | 7100 | 2730 | 2669 | 1416 | 1410 | 2734
(512, 64, 64) (512, 64, 256) | 308900 | 28150 | 12000 | 11000 | 5600 | 5568 | 12000
(1024, 64, 64) (1024, 64, 256) | 615400 | 56470 | 25350 | 22430 | 11000 | 11000 | 25100
(1, 128, 128) (1, 128, 1) | 584 | 42 | 115 | 100 | 67 | 42 | 43
(2, 128, 128) (2, 128, 1) | 1200 | 83 | 122 | 100 | 86 | 83 | 83
(4, 128, 128) (4, 128, 1) | 2312 | 162 | 136 | 100 | 113 | 137 | 136
(8, 128, 128) (8, 128, 1) | 4620 | 324 | 138 | 100 | 197 | 138 | 138
(16, 128, 128) (16, 128, 1) | 9200 | 642 | 143 | 100 | 159 | 144 | 144
(32, 128, 128) (32, 128, 1) | 18400 | 1291 | 153 | 103 | 163 | 153 | 153
(64, 128, 128) (64, 128, 1) | 36770 | 2711 | 178 | 120 | 191 | 178 | 178
(128, 128, 128) (128, 128, 1) | 74000 | 5447 | 231 | 150 | 251 | 231 | 231
(512, 128, 128) (512, 128, 1) | 292800 | 21820 | 505 | 406 | 639 | 505 | 505
(1024, 128, 128) (1024, 128, 1) | 585700 | 43630 | 769 | 685 | 911 | 769 | 769
(1, 128, 128) (1, 128, 8) | 590 | 58 | 102 | 416 | 70 | 65 | 57
(2, 128, 128) (2, 128, 8) | 1170 | 111 | 109 | 436 | 98 | 98 | 111
(4, 128, 128) (4, 128, 8) | 2329 | 221 | 133 | 446 | 170 | 168 | 133
(8, 128, 128) (8, 128, 8) | 4620 | 437 | 139 | 453 | 308 | 306 |
138 (16, 128, 128) (16, 128, 8) | 9260 | 863 | 154 | 460 | 144 | 145 | 154 (32, 128, 128) (32, 128, 8) | 19000 | 1723 | 202 | 480 | 158 | 155 | 202 (64, 128, 128) (64, 128, 8) | 38360 | 3697 | 223 | 493 | 184 | 185 | 224 (128, 128, 128) (128, 128, 8) | 100000 | 7435 | 267 | 532 | 262 | 263 | 268 (512, 128, 128) (512, 128, 8) | 291300 | 29760 | 595 | 827 | 745 | 744 | 596 (1024, 128, 128) (1024, 128, 8) | 572900 | 59560 | 1290 | 1372 | 1132 | 1133 | 1308 (1, 128, 128) (1, 128, 64) | 567 | 61 | 164 | 447 | 59 | 60 | 63 (2, 128, 128) (2, 128, 64) | 1100 | 126 | 206 | 467 | 96 | 98 | 124 (4, 128, 128) (4, 128, 64) | 2848 | 252 | 350 | 490 | 167 | 168 | 352 (8, 128, 128) (8, 128, 64) | 4550 | 482 | 377 | 498 | 307 | 309 | 376 (16, 128, 128) (16, 128, 64) | 9300 | 972 | 578 | 522 | 215 | 217 | 580 (32, 128, 128) (32, 128, 64) | 19000 | 2022 | 1120 | 561 | 312 | 313 | 1120 (64, 128, 128) (64, 128, 64) | 37780 | 4128 | 1260 | 615 | 500 | 501 | 1260 (128, 128, 128) (128, 128, 64) | 75400 | 8234 | 1553 | 1111 | 876 | 876 | 1550 (512, 128, 128) (512, 128, 64) | 303200 | 32920 | 7600 | 6050 | 3345 | 3366 | 7840 (1024, 128, 128) (1024, 128, 64) | 604500 | 65880 | 16500 | 13000 | 6410 | 6400 | 16400 (1, 128, 128) (1, 128, 256) | 648 | 117 | 632 | 460 | 74 | 74 | 116 (2, 128, 128) (2, 128, 256) | 1300 | 231 | 825 | 480 | 126 | 126 | 227 (4, 128, 128) (4, 128, 256) | 2586 | 456 | 1100 | 497 | 227 | 226 | 1100 (8, 128, 128) (8, 128, 256) | 5170 | 920 | 1190 | 542 | 426 | 426 | 1190 (16, 128, 128) (16, 128, 256) | 10400 | 1841 | 2024 | 688 | 481 | 482 | 2035 (32, 128, 128) (32, 128, 256) | 20680 | 3702 | 4243 | 1220 | 836 | 836 | 4252 (64, 128, 128) (64, 128, 256) | 41350 | 7421 | 4803 | 3260 | 1600 | 1600 | 4810 (128, 128, 128) (128, 128, 256) | 83000 | 14850 | 6470 | 7600 | 3075 | 3072 | 6470 (512, 128, 128) (512, 128, 256) | 331800 | 59180 | 30510 | 30480 | 12000 …
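The `lu_solve unpack+solve_triangular` column in these benchmarks refers to solving `A X = B` through two triangular solves on the LU factors rather than through a dedicated LU-solve routine. As a rough illustration of that idea only, here is a minimal sketch in plain Python so that it runs without PyTorch or a GPU. The helper names `lu_factor` and `lu_solve_via_triangular` are made up for this example, and the code is not the implementation from this PR.

```python
def lu_factor(a):
    """LU factorization with partial pivoting (hypothetical helper).

    Returns (lu, perm) where lu stores the unit-lower factor L below the
    diagonal and U on and above it, and row k of L @ U equals a[perm[k]].
    """
    n = len(a)
    lu = [row[:] for row in a]
    perm = list(range(n))
    for k in range(n):
        # Choose the row with the largest entry in column k as the pivot row.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        if lu[p][k] == 0.0:
            raise ValueError("matrix is singular")
        lu[k], lu[p] = lu[p], lu[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm


def lu_solve_via_triangular(lu, perm, b):
    """Solve A X = B from the packed LU factors using two triangular solves."""
    n, m = len(lu), len(b[0])
    # Permute the rows of B to match the pivoting used in the factorization.
    x = [b[perm[i]][:] for i in range(n)]
    # First triangular solve: L Y = P B (L has an implicit unit diagonal).
    for i in range(n):
        for k in range(i):
            for c in range(m):
                x[i][c] -= lu[i][k] * x[k][c]
    # Second triangular solve: U X = Y.
    for i in reversed(range(n)):
        for c in range(m):
            s = x[i][c]
            for k in range(i + 1, n):
                s -= lu[i][k] * x[k][c]
            x[i][c] = s / lu[i][i]
    return x


# B is the identity here, i.e. a square right-hand side as when computing an inverse.
lu, perm = lu_factor([[2.0, 1.0], [4.0, 3.0]])
inverse = lu_solve_via_triangular(lu, perm, [[1.0, 0.0], [0.0, 1.0]])
```

Each column of `B` passes through the same two substitution loops, so the work grows with the number of columns of `B`, which is the dimension the new heuristic takes into account.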
When linalg_lu_solve was added in https://github.com/pytorch/pytorch/pull/72935, I made the big mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be false by a large margin. This PR amends that and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and was crafted by hand but, as the results show, it is effective.

xwang233 the cuSOLVER team should look into this one, as I was able to outperform both the cuBLAS and cuSOLVER algorithms by using triangular solves...

**Edit:** These new heuristics yield a **3x-4x performance improvement** over the previous ones when the **rhs is a square matrix or batch of square matrices** (which is the case when computing inverses, gradients for the determinant, etc).

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[------------------------------ linalg.lu_solve CUDA ------------------------------]
 | lu_solve looped_magma | lu_solve looped_cusolver | lu_solve batched_cublas | lu_solve batched_magma | lu_solve unpack+solve_triangular | lu_solve heuristic | previous_heuristic
1 threads: --------------------------------------------------------------------------
(1, 1, 1) (1, 1, 1) | 591 | 31 | 25 | 50 | 74 | 27 | 26
(2, 1, 1) (2, 1, 1) | 1200 | 47 | 25 | 41 | 81 | 25 | 26
(4, 1, 1) (4, 1, 1) | 2340 | 79 | 25 | 42 | 81 | 26 | 26
(8, 1, 1) (8, 1, 1) | 4680 | 140 | 25 | 43 | 84 | 25 | 26
(16, 1, 1) (16, 1, 1) | 9400 | 268 | 25 | 45 | 82 | 25 | 26
(32, 1, 1) (32, 1, 1) | 18700 | 519 | 25 | 41 | 81 | 26 | 26
(64, 1, 1) (64, 1, 1) | 37240 | 1020 | 25 | 39 | 83 | 26 | 26
(128, 1, 1) (128, 1, 1) | 75000 | 2021 | 26 | 41 | 81 | 26 | 26
(512, 1, 1) (512, 1, 1) | 299400 | 8000 | 25 | 50 | 83 | 26 | 25
(1024, 1, 1) (1024, 1, 1) | 598600 | 16100 | 26 | 55 | 82 | 26 | 26
(1, 1, 1) (1, 1, 8) | 596 | 31 | 25 | 24 | 75 | 25 | 26
(2, 1, 1) (2, 1, 8) | 1190 | 47 | 25 | 25 | 83 | 25 | 26
(4, 1, 1) (4, 1, 8) | 2367 | 78 | 26 | 25 | 83 | 25 | 26
(8, 1, 1) (8, 1, 8) | 4800 | 140 | 25 | 25 | 82 | 25 | 25
(16, 1, 1) (16, 1, 8) | 9500 | 269 | 26 | 26 | 82 | 25 | 25
(32, 1, 1) (32, 1, 8) | 18900 | 536 | 26 | 26 | 82 | 26 | 26
(64, 1, 1) (64, 1, 8) | 37710 | 1066 | 26 | 26 | 84 | 26 | 25
(128, 1, 1) (128, 1, 8) | 75400 | 2131 | 26 | 27 | 82 | 26 | 26
(512, 1, 1) (512, 1, 8) | 301900 | 8520 | 26 | 33 | 80 | 26 | 26
(1024, 1, 1) (1024, 1, 8) | 603600 | 17030 | 26 | 41 | 82 | 26 | 26
(1, 1, 1) (1, 1, 64) | 577 | 23 | 25 | 26 | 75 | 26 | 26
(2, 1, 1) (2, 1, 64) | 1160 | 34 | 26 | 27 | 83 | 26 | 26
(4, 1, 1) (4, 1, 64) | 2334 | 53 | 25 | 27 | 84 | 25 | 26
(8, 1, 1) (8, 1, 64) | 4650 | 92 | 26 | 28 | 82 | 25 | 26
(16, 1, 1) (16, 1, 64) | 9290 | 170 | 25 | 29 | 82 | 26 | 26
(32, 1, 1) (32, 1, 64) | 18600 | 326 | 25 | 29 | 84 | 26 | 26
(64, 1, 1) (64, 1, 64) | 37160 | 634 | 26 | 32 | 83 | 26 | 26
(128, 1, 1) (128, 1, 64) | 74000 | 1260 | 26 | 31 | 82 | 26 | 26
(512, 1, 1) (512, 1, 64) | 297300 | 4960 | 26 | 42 | 82 | 26 | 26
(1024, 1, 1) (1024, 1, 64) | 595400 | 9900 | 43 | 57 | 81 | 43 | 43
(1, 1, 1) (1, 1, 256) | 583 | 24 | 26 | 28 | 75 | 35 | 26
(2, 1, 1) (2, 1, 256) | 1170 | 34 | 25 | 28 | 82 | 27 | 26
(4, 1, 1) (4, 1, 256) | 2328 | 53 | 25 | 28 | 84 | 28 | 26
(8, 1, 1) (8, 1, 256) | 4660 | 93 | 25 | 29 | 83 | 29 | 26
(16, 1, 1) (16, 1, 256) | 9300 | 170 | 25 | 30 | 82 | 30 | 26
(32, 1, 1) (32, 1, 256) | 19000 | 326 | 25 | 35 | 82 | 32 | 26
(64, 1, 1) (64, 1, 256) | 37200 | 638 | 26 | 37 | 82 | 35 | 26
(128, 1, 1) (128, 1, 256) | 74000 | 1300 | 26 | 43 | 82 | 40 | 26
(512, 1, 1) (512, 1, 256) | 297900 | 5000 | 78 | 94 | 99 | 94 | 77
(1024, 1, 1) (1024, 1, 256) | 595100 | 9900 | 146 | 200 | 172 | 164 | 145
(1, 2, 2) (1, 2, 1) | 590 | 31 | 26 | 41 | 76 | 26 | 26
(2, 2, 2) (2, 2, 1) | 1190 | 47 | 26 | 44 | 80 | 26 | 26
(4, 2, 2) (4, 2, 1) | 2357 | 79 | 26 | 41 | 83 | 26 | 26
(8, 2, 2) (8, 2, 1) | 4700 | 140 | 26 | 41 | 82 | 26 | 26
(16, 2, 2) (16, 2, 1) | 9400 | 268 | 26 | 42 | 83 | 26 | 26
(32, 2, 2) (32, 2, 1) | 19000 | 518 | 26 | 44 | 83 | 26 | 26
(64, 2, 2) (64, 2, 1) | 37710 | 1020 | 26 | 42 | 90 | 26 | 26
(128, 2, 2) (128, 2, 1) | 75000 | 2021 | 26 | 42 | 90 | 26 | 26
(512, 2, 2) (512, 2, 1) | 299500 | 8000 | 26 | 50 | 85 | 26 | 26
(1024, 2, 2) (1024, 2, 1) | 599500 | 16030 | 26 | 67 | 99 | 26 | 26
(1, 2, 2) (1, 2, 8) | 600 | 30 | 25 | 26 | 83 | 25 | 25
(2, 2, 2) (2, 2, 8) | 1200 | 46 | 25 | 26 | 85 | 25 | 25
(4, 2, 2) (4, 2, 8) | 2372 | 77 | 25 | 30 | 89 | 27 | 25
(8, 2, 2) (8, 2, 8) | 4740 | 138 | 25 | 28 | 90 | 25 | 25
(16, 2, 2) (16, 2, 8) | 9500 | 274 | 25 | 27 | 90 | 25 | 25
(32, 2, 2) (32, 2, 8) | 19000 | 544 | 25 | 29 | 90 | 25 | 25
(64, 2, 2) (64, 2, 8) | 37680 | 1080 | 25 | 28 | 91 | 25 | 25
(128, 2, 2) (128, 2, 8) | 75500 | 2161 | 25 | 30 | 90 | 25 | 25
(512, 2, 2) (512, 2, 8) | 300700 | 8635 | 25 | 39 | 91 | 25 | 25
(1024, 2, 2) (1024, 2, 8) | 604600 | 17270 | 25 | 52 | 85 | 25 | 25
(1, 2, 2) (1, 2, 64) | 577 | 23 | 25 | 27 | 79 | 25 | 25
(2, 2, 2) (2, 2, 64) | 1157 | 33 | 25 | 28 | 90 | 25 | 25
(4, 2, 2) (4, 2, 64) | 2316 | 53 | 25 | 29 | 90 | 25 | 25
(8, 2, 2) (8, 2, 64) | 4700 | 91 | 25 | 30 | 90 | 25 | 25
(16, 2, 2) (16, 2, 64) | 9290 | 168 | 25 | 30 | 91 | 25 | 25
(32, 2, 2) (32, 2, 64) | 19000 | 322 | 25 | 31 | 90 | 25 | 25
(64, 2, 2) (64, 2, 64) | 37440 | 630 | 25 | 31 | 90 | 25 | 25
(128, 2, 2) (128, 2, 64) | 75000 | 1200 | 26 | 36 | 90 | 25 | 25
(512, 2, 2) (512, 2, 64) | 300000 | 4930 | 33 | 50 | 84 | 33 | 33
(1024, 2, 2) (1024, 2, 64) | 599200 | 9800 | 56 | 70 | 83 | 55 | 56
(1, 2, 2) (1, 2, 256) | 586 | 23 | 25 | 30 | 80 | 28 | 25
(2, 2, 2) (2, 2, 256) | 1170 | 33 | 25 | 31 | 84 | 29 | 25
(4, 2, 2) (4, 2, 256) | 2347 | 53 | 25 | 32 | 83 | 29 | 25
(8, 2, 2) (8, 2, 256) | 4690 | 92 | 25 | 30 | 84 | 30 | 25
(16, 2, 2) (16, 2, 256) | 9400 | 170 | 26 | 31 | 90 | 31 | 25
(32, 2, 2) (32, 2, 256) | 19000 | 322 | 25 | 35 | 83 | 34 | 25
(64, 2, 2) (64, 2, 256) | 37540 | 628 | 25 | 38 | 83 | 36 | 25
(128, 2, 2) (128, 2, 256) | 75000 | 1260 | 37 | 44 | 83 | 43 | 37
(512, 2, 2) (512, 2, 256) | 299500 | 4960 | 100 | 110 | 133 | 103 | 103
(1024, 2, 2) (1024, 2, 256) | 585400 | 10000 | 211 | 210 | 265 | 195 | 209
(1, 8, 8) (1, 8, 1) | 577 | 33 | 26 | 44 | 76 | 26 | 26
(2, 8, 8) (2, 8, 1) | 1200 | 48 | 26 | 45 | 90 | 26 | 26
(4, 8, 8) (4, 8, 1) | 2301 | 80 | 26 | 40 | 84 | 26 | 26
(8, 8, 8) (8, 8, 1) | 4600 | 142 | 26 | 50 | 83 | 26 | 26
(16, 8, 8) (16, 8, 1) | 9100 | 267 | 26 | 50 | 82 | 26 | 26
(32, 8, 8) (32, 8, 1) | 19000 | 519 | 26 | 50 | 83 | 26 | 26
(64, 8, 8) (64, 8, 1) | 36980 | 1020 | 26 | 46 | 90 | 26 | 26
(128, 8, 8) (128, 8, 1) | 74000 | 2053 | 26 | 50 | 85 | 26 | 26
(512, 8, 8) (512, 8, 1) | 295100 | 8100 | 26 | 64 | 82 | 26 | 26
(1024, 8, 8) (1024, 8, 1) | 593200 | 16100 | 26 | 90 | 82 | 26 | 26
(1, 8, 8) (1, 8, 8) | 590 | 32 | 25 | 28 | 77 | 25 | 25
(2, 8, 8) (2, 8, 8) | 1200 | 47 | 25 | 28 | 83 | 25 | 25
(4, 8, 8) (4, 8, 8) | 2322 | 80 | 25 | 29 | 83 | 25 | 25
(8, 8, 8) (8, 8, 8) | 4640 | 145 | 25 | 28 | 83 | 25 | 26
(16, 8, 8) (16, 8, 8) | 9300 | 286 | 25 | 29 | 84 | 26 | 25
(32, 8, 8) (32, 8, 8) | 18500 | 568 | 26 | 32 | 85 | 26 | 25
(64, 8, 8) (64, 8, 8) | 37020 | 1200 | 26 | 30 | 83 | 25 | 25
(128, 8, 8) (128, 8, 8) | 74000 | 2261 | 26 | 31 | 83 | 26 | 25
(512, 8, 8) (512, 8, 8) | 294300 | 9030 | 25 | 41 | 83 | 25 | 25
(1024, 8, 8) (1024, 8, 8) | 592700 | 20000 | 26 | 60 | 82 | 26 | 26
(1, 8, 8) (1, 8, 64) | 564 | 24 | 25 | 27 | 76 | 25 | 25
(2, 8, 8) (2, 8, 64) | 1100 | 34 | 25 | 28 | 82 | 25 | 25
(4, 8, 8) (4, 8, 64) | 2289 | 53 | 25 | 29 | 84 | 25 | 25
(8, 8, 8) (8, 8, 64) | 4500 | 92 | 25 | 30 | 84 | 25 | 25
(16, 8, 8) (16, 8, 64) | 9200 | 170 | 25 | 30 | 83 | 25 | 25
(32, 8, 8) (32, 8, 64) | 18000 | 322 | 25 | 33 | 83 | 25 | 25
(64, 8, 8) (64, 8, 64) | 36870 | 629 | 28 | 32 | 82 | 27 | 28
(128, 8, 8) (128, 8, 64) | 73000 | 1250 | 36 | 38 | 82 | 36 | 35
(512, 8, 8) (512, 8, 64) | 292500 | 5000 | 79 | 88 | 104 | 78 | 79
(1024, 8, 8) (1024, 8, 64) | 582300 | 9900 | 154 | 159 | 196 | 154 | 155
(1, 8, 8) (1, 8, 256) | 582 | 25 | 25 | 28 | 76 | 28 | 25
(2, 8, 8) (2, 8, 256) | 1160 | 35 | 25 | 30 | 83 | 30 | 26
(4, 8, 8) (4, 8, 256) | 2304 | 54 | 31 | 31 | 83 | 30 | 26
(8, 8, 8) (8, 8, 256) | 4600 | 92 | 38 | 34 | 82 | 33 | 41
(16, 8, 8) (16, 8, 256) | 9200 | 169 | 64 | 37 | 82 | 37 | 65
(32, 8, 8) (32, 8, 256) | 18000 | 324 | 74 | 44 | 82 | 42 | 73
(64, 8, 8) (64, 8, 256) | 37090 | 632 | 88 | 61 | 83 | 58 | 88
(128, 8, 8) (128, 8, 256) | 74000 | 1250 | 116 | 91 | 101 | 90 | 115
(512, 8, 8) (512, 8, 256) | 304900 | 5000 | 382 | 314 | 356 | 308 | 383
(1024, 8, 8) (1024, 8, 256) | 609500 | 9800 | 774 | 615 | 672 | 590 | 775
(1, 16, 16) (1, 16, 1) | 596 | 33 | 26 | 46 | 76 | 31 | 25
(2, 16, 16) (2, 16, 1) | 1200 | 48 | 26 | 47 | 83 | 47 | 26
(4, 16, 16) (4, 16, 1) | 2371 | 79 | 26 | 49 | 83 | 26 | 26
(8, 16, 16) (8, 16, 1) | 4726 | 143 | 26 | 50 | 84 | 26 | 26
(16, 16, 16) (16, 16, 1) | 9400 | 270 | 26 | 49 | 83 | 26 | 26
(32, 16, 16) (32, 16, 1) | 18800 | 520 | 26 | 50 | 83 | 26 | 26
(64, 16, 16) (64, 16, 1) | 37280 | 1020 | 26 | 48 | 84 | 26 | 25
(128, 16, 16) (128, 16, 1) | 74000 | 2018 | 26 | 50 | 83 | 26 | 26
(512, 16, 16) (512, 16, 1) | 296300 | 8000 | 26 | 77 | 82 | 26 | 25
(1024, 16, 16) (1024, 16, 1) | 593800 | 16000 | 28 | 110 | 82 | 27 | 28
(1, 16, 16) (1, 16, 8) | 590 | 30 | 25 | 27 | 76 | 30 | 25
(2, 16, 16) (2, 16, 8) | 1200 | 46 | 25 | 29 | 82 | 46 | 25
(4, 16, 16) (4, 16, 8) | 2381 | 82 | 25 | 30 | 82 | 26 | 25
(8, 16, 16) (8, 16, 8) | 4700 | 200 | 25 | 30 | 84 | 26 | 25
(16, 16, 16) (16, 16, 8) | 9400 | 320 | 25 | 30 | 83 | 26 | 25
(32, 16, 16) (32, 16, 8) | 19000 | 640 | 25 | 31 | 83 | 26 | 25
(64, 16, 16) (64, 16, 8) | 37260 | 1234 | 25
| 31 | 91 | 26 | 25 (128, 16, 16) (128, 16, 8) | 74000 | 2542 | 25 | 33 | 90 | 26 | 25 (512, 16, 16) (512, 16, 8) | 299800 | 10000 | 36 | 45 | 82 | 36 | 36 (1024, 16, 16) (1024, 16, 8) | 602300 | 20000 | 56 | 63 | 93 | 56 | 56 (1, 16, 16) (1, 16, 64) | 582 | 24 | 25 | 29 | 76 | 23 | 25 (2, 16, 16) (2, 16, 64) | 1200 | 34 | 25 | 30 | 82 | 32 | 25 (4, 16, 16) (4, 16, 64) | 2318 | 53 | 25 | 30 | 83 | 83 | 25 (8, 16, 16) (8, 16, 64) | 4680 | 92 | 28 | 31 | 83 | 84 | 27 (16, 16, 16) (16, 16, 64) | 9290 | 181 | 42 | 32 | 83 | 84 | 40 (32, 16, 16) (32, 16, 64) | 19000 | 357 | 69 | 35 | 89 | 84 | 70 (64, 16, 16) (64, 16, 64) | 37160 | 714 | 78 | 38 | 82 | 84 | 77 (128, 16, 16) (128, 16, 64) | 74000 | 1420 | 94 | 54 | 81 | 82 | 92 (512, 16, 16) (512, 16, 64) | 299100 | 5836 | 222 | 145 | 235 | 233 | 222 (1024, 16, 16) (1024, 16, 64) | 600700 | 11730 | 572 | 263 | 411 | 410 | 566 (1, 16, 16) (1, 16, 256) | 594 | 23 | 43 | 33 | 76 | 23 | 40 (2, 16, 16) (2, 16, 256) | 1200 | 34 | 48 | 34 | 88 | 33 | 49 (4, 16, 16) (4, 16, 256) | 2378 | 62 | 94 | 35 | 82 | 84 | 91 (8, 16, 16) (8, 16, 256) | 4730 | 118 | 127 | 38 | 82 | 84 | 120 (16, 16, 16) (16, 16, 256) | 9500 | 244 | 142 | 45 | 83 | 84 | 141 (32, 16, 16) (32, 16, 256) | 19000 | 477 | 240 | 59 | 81 | 82 | 239 (64, 16, 16) (64, 16, 256) | 37610 | 969 | 279 | 90 | 118 | 118 | 277 (128, 16, 16) (128, 16, 256) | 76000 | 1960 | 340 | 156 | 218 | 217 | 343 (512, 16, 16) (512, 16, 256) | 302200 | 7910 | 1151 | 542 | 765 | 770 | 1145 (1024, 16, 16) (1024, 16, 256) | 603400 | 15810 | 2404 | 1040 | 1500 | 1490 | 2299 (1, 32, 32) (1, 32, 1) | 591 | 31 | 28 | 53 | 77 | 30 | 28 (2, 32, 32) (2, 32, 1) | 1200 | 46 | 28 | 54 | 83 | 46 | 28 (4, 32, 32) (4, 32, 1) | 2312 | 78 | 31 | 55 | 83 | 31 | 31 (8, 32, 32) (8, 32, 1) | 4600 | 144 | 31 | 56 | 83 | 31 | 31 (16, 32, 32) (16, 32, 1) | 9300 | 287 | 32 | 56 | 82 | 32 | 32 (32, 32, 32) (32, 32, 1) | 18700 | 572 | 35 | 56 | 83 | 35 | 35 (64, 32, 32) (64, 32, 1) | 37620 | 1140 | 36 | 55 | 83 | 36 
| 36 (128, 32, 32) (128, 32, 1) | 75000 | 2275 | 43 | 59 | 82 | 43 | 43 (512, 32, 32) (512, 32, 1) | 300700 | 9080 | 83 | 100 | 131 | 83 | 83 (1024, 32, 32) (1024, 32, 1) | 602000 | 19500 | 123 | 190 | 175 | 123 | 123 (1, 32, 32) (1, 32, 8) | 600 | 30 | 25 | 252 | 76 | 29 | 25 (2, 32, 32) (2, 32, 8) | 1200 | 49 | 25 | 253 | 83 | 49 | 25 (4, 32, 32) (4, 32, 8) | 2388 | 94 | 30 | 254 | 83 | 30 | 31 (8, 32, 32) (8, 32, 8) | 4800 | 183 | 33 | 257 | 83 | 32 | 32 (16, 32, 32) (16, 32, 8) | 9600 | 365 | 34 | 258 | 83 | 35 | 35 (32, 32, 32) (32, 32, 8) | 19100 | 727 | 45 | 260 | 83 | 44 | 44 (64, 32, 32) (64, 32, 8) | 38230 | 1453 | 45 | 260 | 83 | 45 | 45 (128, 32, 32) (128, 32, 8) | 76400 | 2901 | 51 | 262 | 83 | 50 | 51 (512, 32, 32) (512, 32, 8) | 306900 | 12000 | 96 | 303 | 142 | 96 | 95 (1024, 32, 32) (1024, 32, 8) | 610900 | 24520 | 162 | 348 | 191 | 162 | 161 (1, 32, 32) (1, 32, 64) | 586 | 23 | 37 | 264 | 76 | 23 | 38 (2, 32, 32) (2, 32, 64) | 1200 | 39 | 44 | 268 | 85 | 41 | 44 (4, 32, 32) (4, 32, 64) | 2367 | 75 | 76 | 270 | 83 | 83 | 77 (8, 32, 32) (8, 32, 64) | 4750 | 145 | 86 | 273 | 83 | 83 | 84 (16, 32, 32) (16, 32, 64) | 9500 | 298 | 124 | 275 | 83 | 83 | 125 (32, 32, 32) (32, 32, 64) | 18900 | 594 | 239 | 272 | 82 | 82 | 238 (64, 32, 32) (64, 32, 64) | 38020 | 1157 | 265 | 276 | 110 | 110 | 262 (128, 32, 32) (128, 32, 64) | 76000 | 2320 | 320 | 293 | 179 | 180 | 306 (512, 32, 32) (512, 32, 64) | 304300 | 9611 | 840 | 533 | 613 | 614 | 835 (1024, 32, 32) (1024, 32, 64) | 605700 | 19190 | 2000 | 910 | 1130 | 1135 | 1900 (1, 32, 32) (1, 32, 256) | 610 | 34 | 87 | 257 | 77 | 31 | 84 (2, 32, 32) (2, 32, 256) | 1300 | 67 | 200 | 265 | 83 | 63 | 179 (4, 32, 32) (4, 32, 256) | 2441 | 126 | 270 | 264 | 84 | 84 | 247 (8, 32, 32) (8, 32, 256) | 4900 | 238 | 274 | 270 | 83 | 83 | 278 (16, 32, 32) (16, 32, 256) | 9600 | 478 | 450 | 281 | 107 | 107 | 433 (32, 32, 32) (32, 32, 256) | 20000 | 972 | 960 | 295 | 172 | 174 | 903 (64, 32, 32) (64, 32, 256) | 38910 | 1915 | 
1000 | 352 | 308 | 309 | 993 (128, 32, 32) (128, 32, 256) | 80000 | 3833 | 1200 | 620 | 572 | 572 | 1150 (512, 32, 32) (512, 32, 256) | 313000 | 15410 | 3723 | 2000 | 2173 | 2161 | 3621 (1024, 32, 32) (1024, 32, 256) | 625200 | 30880 | 8000 | 3277 | 4276 | 4279 | 7700 (1, 64, 64) (1, 64, 1) | 609 | 30 | 55 | 68 | 66 | 30 | 30 (2, 64, 64) (2, 64, 1) | 1201 | 53 | 61 | 70 | 85 | 53 | 53 (4, 64, 64) (4, 64, 1) | 2393 | 102 | 70 | 70 | 110 | 65 | 65 (8, 64, 64) (8, 64, 1) | 4824 | 200 | 68 | 71 | 155 | 66 | 66 (16, 64, 64) (16, 64, 1) | 9600 | 401 | 70 | 72 | 84 | 67 | 67 (32, 64, 64) (32, 64, 1) | 19000 | 802 | 76 | 72 | 88 | 73 | 73 (64, 64, 64) (64, 64, 1) | 38890 | 1601 | 76 | 72 | 90 | 76 | 76 (128, 64, 64) (128, 64, 1) | 76000 | 3198 | 100 | 77 | 115 | 96 | 96 (512, 64, 64) (512, 64, 1) | 301700 | 13720 | 210 | 216 | 286 | 209 | 209 (1024, 64, 64) (1024, 64, 1) | 598400 | 28090 | 308 | 355 | 386 | 308 | 308 (1, 64, 64) (1, 64, 8) | 600 | 37 | 48 | 302 | 66 | 37 | 36 (2, 64, 64) (2, 64, 8) | 1190 | 69 | 52 | 310 | 86 | 69 | 69 (4, 64, 64) (4, 64, 8) | 2383 | 135 | 64 | 312 | 112 | 64 | 64 (8, 64, 64) (8, 64, 8) | 4730 | 267 | 67 | 317 | 200 | 66 | 66 (16, 64, 64) (16, 64, 8) | 9500 | 535 | 73 | 318 | 83 | 73 | 73 (32, 64, 64) (32, 64, 8) | 19000 | 1060 | 94 | 323 | 83 | 94 | 94 (64, 64, 64) (64, 64, 8) | 37620 | 2129 | 97 | 329 | 86 | 96 | 97 (128, 64, 64) (128, 64, 8) | 75000 | 4245 | 113 | 343 | 119 | 113 | 113 (512, 64, 64) (512, 64, 8) | 299900 | 18150 | 240 | 444 | 312 | 240 | 240 (1024, 64, 64) (1024, 64, 8) | 602200 | 36350 | 410 | 608 | 465 | 410 | 410 (1, 64, 64) (1, 64, 64) | 580 | 35 | 80 | 314 | 59 | 57 | 34 (2, 64, 64) (2, 64, 64) | 1160 | 64 | 86 | 322 | 71 | 70 | 64 (4, 64, 64) (4, 64, 64) | 2323 | 131 | 162 | 326 | 85 | 85 | 162 (8, 64, 64) (8, 64, 64) | 4640 | 254 | 177 | 331 | 147 | 147 | 177 (16, 64, 64) (16, 64, 64) | 9320 | 505 | 271 | 345 | 108 | 108 | 274 (32, 64, 64) (32, 64, 64) | 18700 | 980 | 532 | 361 | 151 | 151 | 533 (64, 64, 64) (64, 
64, 64) | 37290 | 1998 | 595 | 374 | 238 | 238 | 593 (128, 64, 64) (128, 64, 64) | 75000 | 4124 | 717 | 465 | 414 | 415 | 718 (512, 64, 64) (512, 64, 64) | 299900 | 16430 | 2522 | 2114 | 1511 | 1500 | 2739 (1024, 64, 64) (1024, 64, 64) | 596100 | 32950 | 5900 | 4484 | 2859 | 2852 | 5600 (1, 64, 64) (1, 64, 256) | 607 | 59 | 259 | 319 | 59 | 57 | 60 (2, 64, 64) (2, 64, 256) | 1200 | 112 | 380 | 328 | 71 | 70 | 116 (4, 64, 64) (4, 64, 256) | 2429 | 225 | 535 | 332 | 106 | 106 | 535 (8, 64, 64) (8, 64, 256) | 4850 | 444 | 576 | 349 | 189 | 190 | 578 (16, 64, 64) (16, 64, 256) | 9690 | 882 | 960 | 386 | 232 | 232 | 962 (32, 64, 64) (32, 64, 256) | 19000 | 1753 | 2022 | 472 | 399 | 400 | 2024 (64, 64, 64) (64, 64, 256) | 38960 | 3546 | 2249 | 1126 | 742 | 741 | 2250 (128, 64, 64) (128, 64, 256) | 78000 | 7100 | 2730 | 2669 | 1416 | 1410 | 2734 (512, 64, 64) (512, 64, 256) | 308900 | 28150 | 12000 | 11000 | 5600 | 5568 | 12000 (1024, 64, 64) (1024, 64, 256) | 615400 | 56470 | 25350 | 22430 | 11000 | 11000 | 25100 (1, 128, 128) (1, 128, 1) | 584 | 42 | 115 | 100 | 67 | 42 | 43 (2, 128, 128) (2, 128, 1) | 1200 | 83 | 122 | 100 | 86 | 83 | 83 (4, 128, 128) (4, 128, 1) | 2312 | 162 | 136 | 100 | 113 | 137 | 136 (8, 128, 128) (8, 128, 1) | 4620 | 324 | 138 | 100 | 197 | 138 | 138 (16, 128, 128) (16, 128, 1) | 9200 | 642 | 143 | 100 | 159 | 144 | 144 (32, 128, 128) (32, 128, 1) | 18400 | 1291 | 153 | 103 | 163 | 153 | 153 (64, 128, 128) (64, 128, 1) | 36770 | 2711 | 178 | 120 | 191 | 178 | 178 (128, 128, 128) (128, 128, 1) | 74000 | 5447 | 231 | 150 | 251 | 231 | 231 (512, 128, 128) (512, 128, 1) | 292800 | 21820 | 505 | 406 | 639 | 505 | 505 (1024, 128, 128) (1024, 128, 1) | 585700 | 43630 | 769 | 685 | 911 | 769 | 769 (1, 128, 128) (1, 128, 8) | 590 | 58 | 102 | 416 | 70 | 65 | 57 (2, 128, 128) (2, 128, 8) | 1170 | 111 | 109 | 436 | 98 | 98 | 111 (4, 128, 128) (4, 128, 8) | 2329 | 221 | 133 | 446 | 170 | 168 | 133 (8, 128, 128) (8, 128, 8) | 4620 | 437 | 139 | 453 | 308 | 
306 | 138 (16, 128, 128) (16, 128, 8) | 9260 | 863 | 154 | 460 | 144 | 145 | 154 (32, 128, 128) (32, 128, 8) | 19000 | 1723 | 202 | 480 | 158 | 155 | 202 (64, 128, 128) (64, 128, 8) | 38360 | 3697 | 223 | 493 | 184 | 185 | 224 (128, 128, 128) (128, 128, 8) | 100000 | 7435 | 267 | 532 | 262 | 263 | 268 (512, 128, 128) (512, 128, 8) | 291300 | 29760 | 595 | 827 | 745 | 744 | 596 (1024, 128, 128) (1024, 128, 8) | 572900 | 59560 | 1290 | 1372 | 1132 | 1133 | 1308 (1, 128, 128) (1, 128, 64) | 567 | 61 | 164 | 447 | 59 | 60 | 63 (2, 128, 128) (2, 128, 64) | 1100 | 126 | 206 | 467 | 96 | 98 | 124 (4, 128, 128) (4, 128, 64) | 2848 | 252 | 350 | 490 | 167 | 168 | 352 (8, 128, 128) (8, 128, 64) | 4550 | 482 | 377 | 498 | 307 | 309 | 376 (16, 128, 128) (16, 128, 64) | 9300 | 972 | 578 | 522 | 215 | 217 | 580 (32, 128, 128) (32, 128, 64) | 19000 | 2022 | 1120 | 561 | 312 | 313 | 1120 (64, 128, 128) (64, 128, 64) | 37780 | 4128 | 1260 | 615 | 500 | 501 | 1260 (128, 128, 128) (128, 128, 64) | 75400 | 8234 | 1553 | 1111 | 876 | 876 | 1550 (512, 128, 128) (512, 128, 64) | 303200 | 32920 | 7600 | 6050 | 3345 | 3366 | 7840 (1024, 128, 128) (1024, 128, 64) | 604500 | 65880 | 16500 | 13000 | 6410 | 6400 | 16400 (1, 128, 128) (1, 128, 256) | 648 | 117 | 632 | 460 | 74 | 74 | 116 (2, 128, 128) (2, 128, 256) | 1300 | 231 | 825 | 480 | 126 | 126 | 227 (4, 128, 128) (4, 128, 256) | 2586 | 456 | 1100 | 497 | 227 | 226 | 1100 (8, 128, 128) (8, 128, 256) | 5170 | 920 | 1190 | 542 | 426 | 426 | 1190 (16, 128, 128) (16, 128, 256) | 10400 | 1841 | 2024 | 688 | 481 | 482 | 2035 (32, 128, 128) (32, 128, 256) | 20680 | 3702 | 4243 | 1220 | 836 | 836 | 4252 (64, 128, 128) (64, 128, 256) | 41350 | 7421 | 4803 | 3260 | 1600 | 1600 | 4810 (128, 128, 128) (128, 128, 256) | 83000 | 14850 | 6470 | 7600 | 3075 | 3072 | 6470 (512, 128, 128) (512, 128, 256) | 331800 | 59180 | 30510 | 30480 | 12000 …
When `linalg_lu_solve` was added in #72935, I made the big mistake of assuming that the choice of backend would not depend on the number of columns of B. This assumption turned out to be wrong by a large margin. This PR corrects it and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and was crafted by hand but, as the results show, it is effective.

xwang233, the cuSOLVER team should look into this one, as I was able to outperform both the cuBLAS and cuSOLVER algorithms by using triangular solves.

I'll post the benchmarks in a second.

ghstack-source-id: da24197
Pull Request resolved: #79838
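To make the "triangular solves" path concrete, here is a minimal pure-Python sketch of solving A X = B from an LU factorization with two triangular solves, where B may have several columns. It is only an illustration of the technique, not the code in this PR: the helper names `lu_factor` and `lu_solve` below are hypothetical stand-ins, and in PyTorch the analogous path would be `torch.lu_unpack` followed by two calls to `torch.linalg.solve_triangular`.

```python
def lu_factor(a):
    """LU with partial pivoting. Returns (lu, perm) such that the rows of `a`
    taken in the order `perm` equal L @ U, with L unit lower triangular.
    L (below the diagonal) and U (on and above it) are packed into `lu`."""
    n = len(a)
    lu = [row[:] for row in a]
    perm = list(range(n))
    for k in range(n):
        # Pick the row with the largest entry in column k as the pivot row.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        if lu[p][k] == 0.0:
            raise ValueError("matrix is singular")
        lu[k], lu[p] = lu[p], lu[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm


def lu_solve(lu, perm, b):
    """Solve A X = B, where B is a list of rows with one or more columns."""
    n, ncols = len(lu), len(b[0])
    # Apply the row permutation to B.
    x = [b[perm[i]][:] for i in range(n)]
    # First triangular solve: forward substitution with unit lower L.
    for i in range(n):
        for k in range(i):
            for c in range(ncols):
                x[i][c] -= lu[i][k] * x[k][c]
    # Second triangular solve: back substitution with upper U.
    for i in reversed(range(n)):
        for k in range(i + 1, n):
            for c in range(ncols):
                x[i][c] -= lu[i][k] * x[k][c]
        for c in range(ncols):
            x[i][c] /= lu[i][i]
    return x


# Small example with a two-column right-hand side.
a = [[2.0, 1.0, 1.0], [4.0, 3.0, 3.0], [8.0, 7.0, 9.0]]
b = [[4.0, 1.0], [10.0, 3.0], [24.0, 7.0]]
lu, perm = lu_factor(a)
x = lu_solve(lu, perm, b)
```

Every column of B goes through the same two substitutions, so the work of this path grows with the number of columns of B. That is one reason the number of columns is a natural input to the backend heuristic.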
When `linalg_lu_solve` was added in #72935, I made the mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be wrong by a large margin. This PR fixes that with a heuristic that takes the number of columns of B into account. The heuristic is not simple and was crafted by hand, but as the results show, it is effective.

xwang233 the cuSOLVER team should look into this one, as I was able to outperform both the cuBLAS and cuSOLVER algorithms by using triangular solves...

I'll post the benchmarks in a second.

ghstack-source-id: 196c2bf
Pull Request resolved: #79838
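For readers unfamiliar with the "triangular solves" approach mentioned above, here is a minimal pure-Python sketch of the idea: given a packed LU factorization with partial pivoting (P A = L U), the system A X = B is solved by permuting the rows of B, then doing one forward substitution with L and one back substitution with U. This is only an illustration under those assumptions, not the code in this PR; the names `lu_factor_sketch` and `lu_solve_via_triangular` are hypothetical. Note that the cost of both substitutions grows with the number of columns of B, which is the quantity the new heuristic takes into account.

```python
# Illustrative sketch only: hypothetical helper names, no PyTorch involved.

def lu_factor_sketch(a):
    """Return (lu, perm) for a square matrix given as a list of rows.

    lu packs the unit-lower factor L (below the diagonal) and U (on and
    above the diagonal); perm[i] is the original index of row i of P @ A.
    """
    n = len(a)
    lu = [row[:] for row in a]
    perm = list(range(n))
    for k in range(n):
        # Partial pivoting: pick the row with the largest entry in column k.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        if lu[p][k] == 0:
            raise ValueError("matrix is singular")
        lu[k], lu[p] = lu[p], lu[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm


def lu_solve_via_triangular(lu, perm, b):
    """Solve A X = B using two triangular solves; b is a list of n rows."""
    n, m = len(lu), len(b[0])
    # Apply the row permutation: x = P @ B.
    x = [b[perm[i]][:] for i in range(n)]
    # Forward substitution with the unit-lower factor L.
    for i in range(n):
        for k in range(i):
            for j in range(m):
                x[i][j] -= lu[i][k] * x[k][j]
    # Back substitution with the upper factor U.
    for i in reversed(range(n)):
        for k in range(i + 1, n):
            for j in range(m):
                x[i][j] -= lu[i][k] * x[k][j]
        for j in range(m):
            x[i][j] /= lu[i][i]
    return x


a = [[2.0, 1.0], [4.0, 3.0]]
b = [[3.0, 1.0], [7.0, 3.0]]  # two right-hand sides, one per column
lu, perm = lu_factor_sketch(a)
print(lu_solve_via_triangular(lu, perm, b))  # prints [[1.0, 0.0], [1.0, 1.0]]
```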
When `linalg_lu_solve` was added in https://github.com/pytorch/pytorch/pull/72935, I made the mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be wrong by a large margin. This PR fixes that with a heuristic that takes the number of columns of B into account. The heuristic is not simple and was crafted by hand, but as the results show, it is effective.

xwang233 the cuSOLVER team should look into this one, as I was able to outperform both the cuBLAS and cuSOLVER algorithms by using triangular solves...

**Edit:** These new heuristics yield a **3x to 4x performance improvement** over the previous ones when the **rhs is a square matrix or a batch of square matrices** (which is the case when computing inverses, gradients for the determinant, etc.).

<details>
<summary>
Benchmark Results (adjoint=False)
</summary>

```
[------------------------------ linalg.lu_solve CUDA ------------------------------]
| lu_solve looped_magma | lu_solve looped_cusolver | lu_solve batched_cublas | lu_solve batched_magma | lu_solve unpack+solve_triangular | lu_solve heuristic | previous_heuristic
1 threads: -------------------------------------------------------------------------
(1, 1, 1) (1, 1, 1) | 591 | 31 | 25 | 50 | 74 | 27 | 26
(2, 1, 1) (2, 1, 1) | 1200 | 47 | 25 | 41 | 81 | 25 | 26
(4, 1, 1) (4, 1, 1) | 2340 | 79 | 25 | 42 | 81 | 26 | 26
(8, 1, 1) (8, 1, 1) | 4680 | 140 | 25 | 43 | 84 | 25 | 26
(16, 1, 1) (16, 1, 1) | 9400 | 268 | 25 | 45 | 82 | 25 | 26
(32, 1, 1) (32, 1, 1) | 18700 | 519 | 25 | 41 | 81 | 26 | 26
(64, 1, 1) (64, 1, 1) | 37240 | 1020 | 25 | 39 | 83 | 26 | 26
(128, 1, 1) (128, 1, 1) | 75000 | 2021 | 26 | 41 | 81 | 26 | 26
(512, 1, 1) (512, 1, 1) | 299400 | 8000 | 25 | 50 | 83 | 26 | 25
(1024, 1, 1) (1024, 1, 1) | 598600 | 16100 | 26 | 55 | 82 | 26 | 26
(1, 1, 1) (1, 1, 8) | 596 | 31 | 25 | 24 | 75 | 25 | 26
(2, 1, 1) (2, 1, 8) | 1190 | 47 | 25 | 25 | 83 | 25 | 26
(4, 1, 1) (4, 1, 8) | 2367 | 78 | 26 | 25 | 83 | 25 | 26
(8, 1, 1) (8, 1, 8) | 4800 | 140 | 25 | 25 | 82 | 25 | 25
(16, 1, 1) (16, 1, 8) | 9500 | 269 | 26 | 26 | 82 | 25 | 25
(32, 1, 1) (32, 1, 8) | 18900 | 536 | 26 | 26 | 82 | 26 | 26
(64, 1, 1) (64, 1, 8) | 37710 | 1066 | 26 | 26 | 84 | 26 | 25
(128, 1, 1) (128, 1, 8) | 75400 | 2131 | 26 | 27 | 82 | 26 | 26
(512, 1, 1) (512, 1, 8) | 301900 | 8520 | 26 | 33 | 80 | 26 | 26
(1024, 1, 1) (1024, 1, 8) | 603600 | 17030 | 26 | 41 | 82 | 26 | 26
(1, 1, 1) (1, 1, 64) | 577 | 23 | 25 | 26 | 75 | 26 | 26
(2, 1, 1) (2, 1, 64) | 1160 | 34 | 26 | 27 | 83 | 26 | 26
(4, 1, 1) (4, 1, 64) | 2334 | 53 | 25 | 27 | 84 | 25 | 26
(8, 1, 1) (8, 1, 64) | 4650 | 92 | 26 | 28 | 82 | 25 | 26
(16, 1, 1) (16, 1, 64) | 9290 | 170 | 25 | 29 | 82 | 26 | 26
(32, 1, 1) (32, 1, 64) | 18600 | 326 | 25 | 29 | 84 | 26 | 26
(64, 1, 1) (64, 1, 64) | 37160 | 634 | 26 | 32 | 83 | 26 | 26
(128, 1, 1) (128, 1, 64) | 74000 | 1260 | 26 | 31 | 82 | 26 | 26
(512, 1, 1) (512, 1, 64) | 297300 | 4960 | 26 | 42 | 82 | 26 | 26
(1024, 1, 1) (1024, 1, 64) | 595400 | 9900 | 43 | 57 | 81 | 43 | 43
(1, 1, 1) (1, 1, 256) | 583 | 24 | 26 | 28 | 75 | 35 | 26
(2, 1, 1) (2, 1, 256) | 1170 | 34 | 25 | 28 | 82 | 27 | 26
(4, 1, 1) (4, 1, 256) | 2328 | 53 | 25 | 28 | 84 | 28 | 26
(8, 1, 1) (8, 1, 256) | 4660 | 93 | 25 | 29 | 83 | 29 | 26
(16, 1, 1) (16, 1, 256) | 9300 | 170 | 25 | 30 | 82 | 30 | 26
(32, 1, 1) (32, 1, 256) | 19000 | 326 | 25 | 35 | 82 | 32 | 26
(64, 1, 1) (64, 1, 256) | 37200 | 638 | 26 | 37 | 82 | 35 | 26
(128, 1, 1) (128, 1, 256) | 74000 | 1300 | 26 | 43 | 82 | 40 | 26
(512, 1, 1) (512, 1, 256) | 297900 | 5000 | 78 | 94 | 99 | 94 | 77
(1024, 1, 1) (1024, 1, 256) | 595100 | 9900 | 146 | 200 | 172 | 164 | 145
(1, 2, 2) (1, 2, 1) | 590 | 31 | 26 | 41 | 76 | 26 | 26
(2, 2, 2) (2, 2, 1) | 1190 | 47 | 26 | 44 | 80 | 26 | 26
(4, 2, 2) (4, 2, 1) | 2357 | 79 | 26 | 41 | 83 | 26 | 26
(8, 2, 2) (8, 2, 1) | 4700 | 140 | 26 | 41 | 82 | 26 | 26
(16, 2, 2) (16, 2, 1) | 9400 | 268 | 26 | 42 | 83 | 26 | 26
(32, 2, 2) (32, 2, 1) | 19000 | 518 | 26 | 44 | 83 | 26 | 26
(64, 2, 2) (64, 2, 1) | 37710 | 1020 | 26 | 42 | 90 | 26 | 26
(128, 2, 2) (128, 2, 1) | 75000 | 2021 | 26 | 42 | 90 | 26 | 26
(512, 2, 2) (512, 2, 1) | 299500 | 8000 | 26 | 50 | 85 | 26 | 26
(1024, 2, 2) (1024, 2, 1) | 599500 | 16030 | 26 | 67 | 99 | 26 | 26
(1, 2, 2) (1, 2, 8) | 600 | 30 | 25 | 26 | 83 | 25 | 25
(2, 2, 2) (2, 2, 8) | 1200 | 46 | 25 | 26 | 85 | 25 | 25
(4, 2, 2) (4, 2, 8) | 2372 | 77 | 25 | 30 | 89 | 27 | 25
(8, 2, 2) (8, 2, 8) | 4740 | 138 | 25 | 28 | 90 | 25 | 25
(16, 2, 2) (16, 2, 8) | 9500 | 274 | 25 | 27 | 90 | 25 | 25
(32, 2, 2) (32, 2, 8) | 19000 | 544 | 25 | 29 | 90 | 25 | 25
(64, 2, 2) (64, 2, 8) | 37680 | 1080 | 25 | 28 | 91 | 25 | 25
(128, 2, 2) (128, 2, 8) | 75500 | 2161 | 25 | 30 | 90 | 25 | 25
(512, 2, 2) (512, 2, 8) | 300700 | 8635 | 25 | 39 | 91 | 25 | 25
(1024, 2, 2) (1024, 2, 8) | 604600 | 17270 | 25 | 52 | 85 | 25 | 25
(1, 2, 2) (1, 2, 64) | 577 | 23 | 25 | 27 | 79 | 25 | 25
(2, 2, 2) (2, 2, 64) | 1157 | 33 | 25 | 28 | 90 | 25 | 25
(4, 2, 2) (4, 2, 64) | 2316 | 53 | 25 | 29 | 90 | 25 | 25
(8, 2, 2) (8, 2, 64) | 4700 | 91 | 25 | 30 | 90 | 25 | 25
(16, 2, 2) (16, 2, 64) | 9290 | 168 | 25 | 30 | 91 | 25 | 25
(32, 2, 2) (32, 2, 64) | 19000 | 322 | 25 | 31 | 90 | 25 | 25
(64, 2, 2) (64, 2, 64) | 37440 | 630 | 25 | 31 | 90 | 25 | 25
(128, 2, 2) (128, 2, 64) | 75000 | 1200 | 26 | 36 | 90 | 25 | 25
(512, 2, 2) (512, 2, 64) | 300000 | 4930 | 33 | 50 | 84 | 33 | 33
(1024, 2, 2) (1024, 2, 64) | 599200 | 9800 | 56 | 70 | 83 | 55 | 56
(1, 2, 2) (1, 2, 256) | 586 | 23 | 25 | 30 | 80 | 28 | 25
(2, 2, 2) (2, 2, 256) | 1170 | 33 | 25 | 31 | 84 | 29 | 25
(4, 2, 2) (4, 2, 256) | 2347 | 53 | 25 | 32 | 83 | 29 | 25
(8, 2, 2) (8, 2, 256) | 4690 | 92 | 25 | 30 | 84 | 30 | 25
(16, 2, 2) (16, 2, 256) | 9400 | 170 | 26 | 31 | 90 | 31 | 25
(32, 2, 2) (32, 2, 256) | 19000 | 322 | 25 | 35 | 83 | 34 | 25
(64, 2, 2) (64, 2, 256) | 37540 | 628 | 25 | 38 | 83 | 36 | 25
(128, 2, 2) (128, 2, 256) | 75000 | 1260 | 37 | 44 | 83 | 43 | 37
(512, 2, 2) (512, 2, 256) | 299500 | 4960 | 100 | 110 | 133 | 103 | 103
(1024, 2, 2) (1024, 2, 256) | 585400 | 10000 | 211 | 210 | 265 | 195 | 209
(1, 8, 8) (1, 8, 1) | 577 | 33 | 26 | 44 | 76 | 26 | 26
(2, 8, 8) (2, 8, 1) | 1200 | 48 | 26 | 45 | 90 | 26 | 26
(4, 8, 8) (4, 8, 1) | 2301 | 80 | 26 | 40 | 84 | 26 | 26
(8, 8, 8) (8, 8, 1) | 4600 | 142 | 26 | 50 | 83 | 26 | 26
(16, 8, 8) (16, 8, 1) | 9100 | 267 | 26 | 50 | 82 | 26 | 26
(32, 8, 8) (32, 8, 1) | 19000 | 519 | 26 | 50 | 83 | 26 | 26
(64, 8, 8) (64, 8, 1) | 36980 | 1020 | 26 | 46 | 90 | 26 | 26
(128, 8, 8) (128, 8, 1) | 74000 | 2053 | 26 | 50 | 85 | 26 | 26
(512, 8, 8) (512, 8, 1) | 295100 | 8100 | 26 | 64 | 82 | 26 | 26
(1024, 8, 8) (1024, 8, 1) | 593200 | 16100 | 26 | 90 | 82 | 26 | 26
(1, 8, 8) (1, 8, 8) | 590 | 32 | 25 | 28 | 77 | 25 | 25
(2, 8, 8) (2, 8, 8) | 1200 | 47 | 25 | 28 | 83 | 25 | 25
(4, 8, 8) (4, 8, 8) | 2322 | 80 | 25 | 29 | 83 | 25 | 25
(8, 8, 8) (8, 8, 8) | 4640 | 145 | 25 | 28 | 83 | 25 | 26
(16, 8, 8) (16, 8, 8) | 9300 | 286 | 25 | 29 | 84 | 26 | 25
(32, 8, 8) (32, 8, 8) | 18500 | 568 | 26 | 32 | 85 | 26 | 25
(64, 8, 8) (64, 8, 8) | 37020 | 1200 | 26 | 30 | 83 | 25 | 25
(128, 8, 8) (128, 8, 8) | 74000 | 2261 | 26 | 31 | 83 | 26 | 25
(512, 8, 8) (512, 8, 8) | 294300 | 9030 | 25 | 41 | 83 | 25 | 25
(1024, 8, 8) (1024, 8, 8) | 592700 | 20000 | 26 | 60 | 82 | 26 | 26
(1, 8, 8) (1, 8, 64) | 564 | 24 | 25 | 27 | 76 | 25 | 25
(2, 8, 8) (2, 8, 64) | 1100 | 34 | 25 | 28 | 82 | 25 | 25
(4, 8, 8) (4, 8, 64) | 2289 | 53 | 25 | 29 | 84 | 25 | 25
(8, 8, 8) (8, 8, 64) | 4500 | 92 | 25 | 30 | 84 | 25 | 25
(16, 8, 8) (16, 8, 64) | 9200 | 170 | 25 | 30 | 83 | 25 | 25
(32, 8, 8) (32, 8, 64) | 18000 | 322 | 25 | 33 | 83 | 25 | 25
(64, 8, 8) (64, 8, 64) | 36870 | 629 | 28 | 32 | 82 | 27 | 28
(128, 8, 8) (128, 8, 64) | 73000 | 1250 | 36 | 38 | 82 | 36 | 35
(512, 8, 8) (512, 8, 64) | 292500 | 5000 | 79 | 88 | 104 | 78 | 79
(1024, 8, 8) (1024, 8, 64) | 582300 | 9900 | 154 | 159 | 196 | 154 | 155
(1, 8, 8) (1, 8, 256) | 582 | 25 | 25 | 28 | 76 | 28 | 25
(2, 8, 8) (2, 8, 256) | 1160 | 35 | 25 | 30 | 83 | 30 | 26
(4, 8, 8) (4, 8, 256) | 2304 | 54 | 31 | 31 | 83 | 30 | 26
(8, 8, 8) (8, 8, 256) | 4600 | 92 | 38 | 34 | 82 | 33 | 41
(16, 8, 8) (16, 8, 256) | 9200 | 169 | 64 | 37 | 82 | 37 | 65
(32, 8, 8) (32, 8, 256) | 18000 | 324 | 74 | 44 | 82 | 42 | 73
(64, 8, 8) (64, 8, 256) | 37090 | 632 | 88 | 61 | 83 | 58 | 88
(128, 8, 8) (128, 8, 256) | 74000 | 1250 | 116 | 91 | 101 | 90 | 115
(512, 8, 8) (512, 8, 256) | 304900 | 5000 | 382 | 314 | 356 | 308 | 383
(1024, 8, 8) (1024, 8, 256) | 609500 | 9800 | 774 | 615 | 672 | 590 | 775
(1, 16, 16) (1, 16, 1) | 596 | 33 | 26 | 46 | 76 | 31 | 25
(2, 16, 16) (2, 16, 1) | 1200 | 48 | 26 | 47 | 83 | 47 | 26
(4, 16, 16) (4, 16, 1) | 2371 | 79 | 26 | 49 | 83 | 26 | 26
(8, 16, 16) (8, 16, 1) | 4726 | 143 | 26 | 50 | 84 | 26 | 26
(16, 16, 16) (16, 16, 1) | 9400 | 270 | 26 | 49 | 83 | 26 | 26
(32, 16, 16) (32, 16, 1) | 18800 | 520 | 26 | 50 | 83 | 26 | 26
(64, 16, 16) (64, 16, 1) | 37280 | 1020 | 26 | 48 | 84 | 26 | 25
(128, 16, 16) (128, 16, 1) | 74000 | 2018 | 26 | 50 | 83 | 26 | 26
(512, 16, 16) (512, 16, 1) | 296300 | 8000 | 26 | 77 | 82 | 26 | 25
(1024, 16, 16) (1024, 16, 1) | 593800 | 16000 | 28 | 110 | 82 | 27 | 28
(1, 16, 16) (1, 16, 8) | 590 | 30 | 25 | 27 | 76 | 30 | 25
(2, 16, 16) (2, 16, 8) | 1200 | 46 | 25 | 29 | 82 | 46 | 25
(4, 16, 16) (4, 16, 8) | 2381 | 82 | 25 | 30 | 82 | 26 | 25
(8, 16, 16) (8, 16, 8) | 4700 | 200 | 25 | 30 | 84 | 26 | 25
(16, 16, 16) (16, 16, 8) | 9400 | 320 | 25 | 30 | 83 | 26 | 25
(32, 16, 16) (32, 16, 8) | 19000 | 640 | 25 | 31 | 83 | 26 | 25
(64, 16, 16) (64, 16, 8) | 37260 | 1234 | 25 | 31 | 91 | 26 | 25
(128, 16, 16) (128, 16, 8) | 74000 | 2542 | 25 | 33 | 90 | 26 | 25
(512, 16, 16) (512, 16, 8) | 299800 | 10000 | 36 | 45 | 82 | 36 | 36
(1024, 16, 16) (1024, 16, 8) | 602300 | 20000 | 56 | 63 | 93 | 56 | 56
(1, 16, 16) (1, 16, 64) | 582 | 24 | 25 | 29 | 76 | 23 | 25
(2, 16, 16) (2, 16, 64) | 1200 | 34 | 25 | 30 | 82 | 32 | 25
(4, 16, 16) (4, 16, 64) | 2318 | 53 | 25 | 30 | 83 | 83 | 25
(8, 16, 16) (8, 16, 64) | 4680 | 92 | 28 | 31 | 83 | 84 | 27
(16, 16, 16) (16, 16, 64) | 9290 | 181 | 42 | 32 | 83 | 84 | 40
(32, 16, 16) (32, 16, 64) | 19000 | 357 | 69 | 35 | 89 | 84 | 70
(64, 16, 16) (64, 16, 64) | 37160 | 714 | 78 | 38 | 82 | 84 | 77
(128, 16, 16) (128, 16, 64) | 74000 | 1420 | 94 | 54 | 81 | 82 | 92
(512, 16, 16) (512, 16, 64) | 299100 | 5836 | 222 | 145 | 235 | 233 | 222
(1024, 16, 16) (1024, 16, 64) | 600700 | 11730 | 572 | 263 | 411 | 410 | 566
(1, 16, 16) (1, 16, 256) | 594 | 23 | 43 | 33 | 76 | 23 | 40
(2, 16, 16) (2, 16, 256) | 1200 | 34 | 48 | 34 | 88 | 33 | 49
(4, 16, 16) (4, 16, 256) | 2378 | 62 | 94 | 35 | 82 | 84 | 91
(8, 16, 16) (8, 16, 256) | 4730 | 118 | 127 | 38 | 82 | 84 | 120
(16, 16, 16) (16, 16, 256) | 9500 | 244 | 142 | 45 | 83 | 84 | 141
(32, 16, 16) (32, 16, 256) | 19000 | 477 | 240 | 59 | 81 | 82 | 239
(64, 16, 16) (64, 16, 256) | 37610 | 969 | 279 | 90 | 118 | 118 | 277
(128, 16, 16) (128, 16, 256) | 76000 | 1960 | 340 | 156 | 218 | 217 | 343
(512, 16, 16) (512, 16, 256) | 302200 | 7910 | 1151 | 542 | 765 | 770 | 1145
(1024, 16, 16) (1024, 16, 256) | 603400 | 15810 | 2404 | 1040 | 1500 | 1490 | 2299
(1, 32, 32) (1, 32, 1) | 591 | 31 | 28 | 53 | 77 | 30 | 28
(2, 32, 32) (2, 32, 1) | 1200 | 46 | 28 | 54 | 83 | 46 | 28
(4, 32, 32) (4, 32, 1) | 2312 | 78 | 31 | 55 | 83 | 31 | 31
(8, 32, 32) (8, 32, 1) | 4600 | 144 | 31 | 56 | 83 | 31 | 31
(16, 32, 32) (16, 32, 1) | 9300 | 287 | 32 | 56 | 82 | 32 | 32
(32, 32, 32) (32, 32, 1) | 18700 | 572 | 35 | 56 | 83 | 35 | 35
(64, 32, 32) (64, 32, 1) | 37620 | 1140 | 36 | 55 | 83 | 36 | 36
(128, 32, 32) (128, 32, 1) | 75000 | 2275 | 43 | 59 | 82 | 43 | 43
(512, 32, 32) (512, 32, 1) | 300700 | 9080 | 83 | 100 | 131 | 83 | 83
(1024, 32, 32) (1024, 32, 1) | 602000 | 19500 | 123 | 190 | 175 | 123 | 123
(1, 32, 32) (1, 32, 8) | 600 | 30 | 25 | 252 | 76 | 29 | 25
(2, 32, 32) (2, 32, 8) | 1200 | 49 | 25 | 253 | 83 | 49 | 25
(4, 32, 32) (4, 32, 8) | 2388 | 94 | 30 | 254 | 83 | 30 | 31
(8, 32, 32) (8, 32, 8) | 4800 | 183 | 33 | 257 | 83 | 32 | 32
(16, 32, 32) (16, 32, 8) | 9600 | 365 | 34 | 258 | 83 | 35 | 35
(32, 32, 32) (32, 32, 8) | 19100 | 727 | 45 | 260 | 83 | 44 | 44
(64, 32, 32) (64, 32, 8) | 38230 | 1453 | 45 | 260 | 83 | 45 | 45
(128, 32, 32) (128, 32, 8) | 76400 | 2901 | 51 | 262 | 83 | 50 | 51
(512, 32, 32) (512, 32, 8) | 306900 | 12000 | 96 | 303 | 142 | 96 | 95
(1024, 32, 32) (1024, 32, 8) | 610900 | 24520 | 162 | 348 | 191 | 162 | 161
(1, 32, 32) (1, 32, 64) | 586 | 23 | 37 | 264 | 76 | 23 | 38
(2, 32, 32) (2, 32, 64) | 1200 | 39 | 44 | 268 | 85 | 41 | 44
(4, 32, 32) (4, 32, 64) | 2367 | 75 | 76 | 270 | 83 | 83 | 77
(8, 32, 32) (8, 32, 64) | 4750 | 145 | 86 | 273 | 83 | 83 | 84
(16, 32, 32) (16, 32, 64) | 9500 | 298 | 124 | 275 | 83 | 83 | 125
(32, 32, 32) (32, 32, 64) | 18900 | 594 | 239 | 272 | 82 | 82 | 238
(64, 32, 32) (64, 32, 64) | 38020 | 1157 | 265 | 276 | 110 | 110 | 262
(128, 32, 32) (128, 32, 64) | 76000 | 2320 | 320 | 293 | 179 | 180 | 306
(512, 32, 32) (512, 32, 64) | 304300 | 9611 | 840 | 533 | 613 | 614 | 835
(1024, 32, 32) (1024, 32, 64) | 605700 | 19190 | 2000 | 910 | 1130 | 1135 | 1900
(1, 32, 32) (1, 32, 256) | 610 | 34 | 87 | 257 | 77 | 31 | 84
(2, 32, 32) (2, 32, 256) | 1300 | 67 | 200 | 265 | 83 | 63 | 179
(4, 32, 32) (4, 32, 256) | 2441 | 126 | 270 | 264 | 84 | 84 | 247
(8, 32, 32) (8, 32, 256) | 4900 | 238 | 274 | 270 | 83 | 83 | 278
(16, 32, 32) (16, 32, 256) | 9600 | 478 | 450 | 281 | 107 | 107 | 433
(32, 32, 32) (32, 32, 256) | 20000 | 972 | 960 | 295 | 172 | 174 | 903
(64, 32, 32) (64, 32, 256) | 38910 | 1915 | 1000 | 352 | 308 | 309 | 993
(128, 32, 32) (128, 32, 256) | 80000 | 3833 | 1200 | 620 | 572 | 572 | 1150
(512, 32, 32) (512, 32, 256) | 313000 | 15410 | 3723 | 2000 | 2173 | 2161 | 3621
(1024, 32, 32) (1024, 32, 256) | 625200 | 30880 | 8000 | 3277 | 4276 | 4279 | 7700
(1, 64, 64) (1, 64, 1) | 609 | 30 | 55 | 68 | 66 | 30 | 30
(2, 64, 64) (2, 64, 1) | 1201 | 53 | 61 | 70 | 85 | 53 | 53
(4, 64, 64) (4, 64, 1) | 2393 | 102 | 70 | 70 | 110 | 65 | 65
(8, 64, 64) (8, 64, 1) | 4824 | 200 | 68 | 71 | 155 | 66 | 66
(16, 64, 64) (16, 64, 1) | 9600 | 401 | 70 | 72 | 84 | 67 | 67
(32, 64, 64) (32, 64, 1) | 19000 | 802 | 76 | 72 | 88 | 73 | 73
(64, 64, 64) (64, 64, 1) | 38890 | 1601 | 76 | 72 | 90 | 76 | 76
(128, 64, 64) (128, 64, 1) | 76000 | 3198 | 100 | 77 | 115 | 96 | 96
(512, 64, 64) (512, 64, 1) | 301700 | 13720 | 210 | 216 | 286 | 209 | 209
(1024, 64, 64) (1024, 64, 1) | 598400 | 28090 | 308 | 355 | 386 | 308 | 308
(1, 64, 64) (1, 64, 8) | 600 | 37 | 48 | 302 | 66 | 37 | 36
(2, 64, 64) (2, 64, 8) | 1190 | 69 | 52 | 310 | 86 | 69 | 69
(4, 64, 64) (4, 64, 8) | 2383 | 135 | 64 | 312 | 112 | 64 | 64
(8, 64, 64) (8, 64, 8) | 4730 | 267 | 67 | 317 | 200 | 66 | 66
(16, 64, 64) (16, 64, 8) | 9500 | 535 | 73 | 318 | 83 | 73 | 73
(32, 64, 64) (32, 64, 8) | 19000 | 1060 | 94 | 323 | 83 | 94 | 94
(64, 64, 64) (64, 64, 8) | 37620 | 2129 | 97 | 329 | 86 | 96 | 97
(128, 64, 64) (128, 64, 8) | 75000 | 4245 | 113 | 343 | 119 | 113 | 113
(512, 64, 64) (512, 64, 8) | 299900 | 18150 | 240 | 444 | 312 | 240 | 240
(1024, 64, 64) (1024, 64, 8) | 602200 | 36350 | 410 | 608 | 465 | 410 | 410
(1, 64, 64) (1, 64, 64) | 580 | 35 | 80 | 314 | 59 | 57 | 34
(2, 64, 64) (2, 64, 64) | 1160 | 64 | 86 | 322 | 71 | 70 | 64
(4, 64, 64) (4, 64, 64) | 2323 | 131 | 162 | 326 | 85 | 85 | 162
(8, 64, 64) (8, 64, 64) | 4640 | 254 | 177 | 331 | 147 | 147 | 177
(16, 64, 64) (16, 64, 64) | 9320 | 505 | 271 | 345 | 108 | 108 | 274
(32, 64, 64) (32, 64, 64) | 18700 | 980 | 532 | 361 | 151 | 151 | 533
(64, 64, 64) (64, 64, 64) | 37290 | 1998 | 595 | 374 | 238 | 238 | 593
(128, 64, 64) (128, 64, 64) | 75000 | 4124 | 717 | 465 | 414 | 415 | 718
(512, 64, 64) (512, 64, 64) | 299900 | 16430 | 2522 | 2114 | 1511 | 1500 | 2739
(1024, 64, 64) (1024, 64, 64) | 596100 | 32950 | 5900 | 4484 | 2859 | 2852 | 5600
(1, 64, 64) (1, 64, 256) | 607 | 59 | 259 | 319 | 59 | 57 | 60
(2, 64, 64) (2, 64, 256) | 1200 | 112 | 380 | 328 | 71 | 70 | 116
(4, 64, 64) (4, 64, 256) | 2429 | 225 | 535 | 332 | 106 | 106 | 535
(8, 64, 64) (8, 64, 256) | 4850 | 444 | 576 | 349 | 189 | 190 | 578
(16, 64, 64) (16, 64, 256) | 9690 | 882 | 960 | 386 | 232 | 232 | 962
(32, 64, 64) (32, 64, 256) | 19000 | 1753 | 2022 | 472 | 399 | 400 | 2024
(64, 64, 64) (64, 64, 256) | 38960 | 3546 | 2249 | 1126 | 742 | 741 | 2250
(128, 64, 64) (128, 64, 256) | 78000 | 7100 | 2730 | 2669 | 1416 | 1410 | 2734
(512, 64, 64) (512, 64, 256) | 308900 | 28150 | 12000 | 11000 | 5600 | 5568 | 12000
(1024, 64, 64) (1024, 64, 256) | 615400 | 56470 | 25350 | 22430 | 11000 | 11000 | 25100
(1, 128, 128) (1, 128, 1) | 584 | 42 | 115 | 100 | 67 | 42 | 43
(2, 128, 128) (2, 128, 1) | 1200 | 83 | 122 | 100 | 86 | 83 | 83
(4, 128, 128) (4, 128, 1) | 2312 | 162 | 136 | 100 | 113 | 137 | 136
(8, 128, 128) (8, 128, 1) | 4620 | 324 | 138 | 100 | 197 | 138 | 138
(16, 128, 128) (16, 128, 1) | 9200 | 642 | 143 | 100 | 159 | 144 | 144
(32, 128, 128) (32, 128, 1) | 18400 | 1291 | 153 | 103 | 163 | 153 | 153
(64, 128, 128) (64, 128, 1) | 36770 | 2711 | 178 | 120 | 191 | 178 | 178
(128, 128, 128) (128, 128, 1) | 74000 | 5447 | 231 | 150 | 251 | 231 | 231
(512, 128, 128) (512, 128, 1) | 292800 | 21820 | 505 | 406 | 639 | 505 | 505
(1024, 128, 128) (1024, 128, 1) | 585700 | 43630 | 769 | 685 | 911 | 769 | 769
(1, 128, 128) (1, 128, 8) | 590 | 58 | 102 | 416 | 70 | 65 | 57
(2, 128, 128) (2, 128, 8) | 1170 | 111 | 109 | 436 | 98 | 98 | 111
(4, 128, 128) (4, 128, 8) | 2329 | 221 | 133 | 446 | 170 | 168 | 133
(8, 128, 128) (8, 128, 8) | 4620 | 437 | 139 | 453 | 308 | 306 | 138
(16, 128, 128) (16, 128, 8) | 9260 | 863 | 154 | 460 | 144 | 145 | 154
(32, 128, 128) (32, 128, 8) | 19000 | 1723 | 202 | 480 | 158 | 155 | 202
(64, 128, 128) (64, 128, 8) | 38360 | 3697 | 223 | 493 | 184 | 185 | 224
(128, 128, 128) (128, 128, 8) | 100000 | 7435 | 267 | 532 | 262 | 263 | 268
(512, 128, 128) (512, 128, 8) | 291300 | 29760 | 595 | 827 | 745 | 744 | 596
(1024, 128, 128) (1024, 128, 8) | 572900 | 59560 | 1290 | 1372 | 1132 | 1133 | 1308
(1, 128, 128) (1, 128, 64) | 567 | 61 | 164 | 447 | 59 | 60 | 63
(2, 128, 128) (2, 128, 64) | 1100 | 126 | 206 | 467 | 96 | 98 | 124
(4, 128, 128) (4, 128, 64) | 2848 | 252 | 350 | 490 | 167 | 168 | 352
(8, 128, 128) (8, 128, 64) | 4550 | 482 | 377 | 498 | 307 | 309 | 376
(16, 128, 128) (16, 128, 64) | 9300 | 972 | 578 | 522 | 215 | 217 | 580
(32, 128, 128) (32, 128, 64) | 19000 | 2022 | 1120 | 561 | 312 | 313 | 1120
(64, 128, 128) (64, 128, 64) | 37780 | 4128 | 1260 | 615 | 500 | 501 | 1260
(128, 128, 128) (128, 128, 64) | 75400 | 8234 | 1553 | 1111 | 876 | 876 | 1550
(512, 128, 128) (512, 128, 64) | 303200 | 32920 | 7600 | 6050 | 3345 | 3366 | 7840
(1024, 128, 128) (1024, 128, 64) | 604500 | 65880 | 16500 | 13000 | 6410 | 6400 | 16400
(1, 128, 128) (1, 128, 256) | 648 | 117 | 632 | 460 | 74 | 74 | 116
(2, 128, 128) (2, 128, 256) | 1300 | 231 | 825 | 480 | 126 | 126 | 227
(4, 128, 128) (4, 128, 256) | 2586 | 456 | 1100 | 497 | 227 | 226 | 1100
(8, 128, 128) (8, 128, 256) | 5170 | 920 | 1190 | 542 | 426 | 426 | 1190
(16, 128, 128) (16, 128, 256) | 10400 | 1841 | 2024 | 688 | 481 | 482 | 2035
(32, 128, 128) (32, 128, 256) | 20680 | 3702 | 4243 | 1220 | 836 | 836 | 4252
(64, 128, 128) (64, 128, 256) | 41350 | 7421 | 4803 | 3260 | 1600 | 1600 | 4810
(128, 128, 128) (128, 128, 256) | 83000 | 14850 | 6470 | 7600 | 3075 | 3072 | 6470
(512, 128, 128) (512, 128, 256) | 331800 | 59180 | 30510 | 30480 | 12000 …
```

</details>
…n B is a matrix" When linalg_lu_solve was added in https://github.com/pytorch/pytorch/pull/72935 I made the big mistake of assuming that the choice of backend would not depend on number of columns of B. This turned out to be false by a large margin. This PR amends this and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and it was crafted by hand, but as the results show, it is effective. xwang233 the cusolver team should look into this one, as I was able to outperform both cublas and cusolvers algorithms by using triangular solves... **Edit:** These new heuristics yield a **x3-x4 performance improvement** over the previous ones when the **rhs is a square matrix or batch of square matrices** (which is the case when computing inverses, gradients for the determinant, etc). <details> <summary> Benchmark Results (adjoint=False) </summary> ``` [--------------------------------------------------------------------------------------------------------- linalg.lu_solve CUDA --------------------------------------------------------------------------------------------------------] | lu_solve looped_magma | lu_solve looped_cusolver | lu_solve batched_cublas | lu_solve batched_magma | lu_solve unpack+solve_triangular | lu_solve heuristic | previous_heuristic 1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1, 1, 1) (1, 1, 1) | 591 | 31 | 25 | 50 | 74 | 27 | 26 (2, 1, 1) (2, 1, 1) | 1200 | 47 | 25 | 41 | 81 | 25 | 26 (4, 1, 1) (4, 1, 1) | 2340 | 79 | 25 | 42 | 81 | 26 | 26 (8, 1, 1) (8, 1, 1) | 4680 | 140 | 25 | 43 | 84 | 25 | 26 (16, 1, 1) (16, 1, 1) | 9400 | 268 | 25 | 45 | 82 | 25 | 26 (32, 1, 1) (32, 1, 1) | 18700 | 519 | 25 | 41 | 81 | 26 | 26 (64, 1, 1) (64, 1, 1) | 37240 | 1020 | 25 | 39 | 83 | 26 | 26 (128, 1, 1) (128, 1, 1) | 75000 | 
2021 | 26 | 41 | 81 | 26 | 26 (512, 1, 1) (512, 1, 1) | 299400 | 8000 | 25 | 50 | 83 | 26 | 25 (1024, 1, 1) (1024, 1, 1) | 598600 | 16100 | 26 | 55 | 82 | 26 | 26 (1, 1, 1) (1, 1, 8) | 596 | 31 | 25 | 24 | 75 | 25 | 26 (2, 1, 1) (2, 1, 8) | 1190 | 47 | 25 | 25 | 83 | 25 | 26 (4, 1, 1) (4, 1, 8) | 2367 | 78 | 26 | 25 | 83 | 25 | 26 (8, 1, 1) (8, 1, 8) | 4800 | 140 | 25 | 25 | 82 | 25 | 25 (16, 1, 1) (16, 1, 8) | 9500 | 269 | 26 | 26 | 82 | 25 | 25 (32, 1, 1) (32, 1, 8) | 18900 | 536 | 26 | 26 | 82 | 26 | 26 (64, 1, 1) (64, 1, 8) | 37710 | 1066 | 26 | 26 | 84 | 26 | 25 (128, 1, 1) (128, 1, 8) | 75400 | 2131 | 26 | 27 | 82 | 26 | 26 (512, 1, 1) (512, 1, 8) | 301900 | 8520 | 26 | 33 | 80 | 26 | 26 (1024, 1, 1) (1024, 1, 8) | 603600 | 17030 | 26 | 41 | 82 | 26 | 26 (1, 1, 1) (1, 1, 64) | 577 | 23 | 25 | 26 | 75 | 26 | 26 (2, 1, 1) (2, 1, 64) | 1160 | 34 | 26 | 27 | 83 | 26 | 26 (4, 1, 1) (4, 1, 64) | 2334 | 53 | 25 | 27 | 84 | 25 | 26 (8, 1, 1) (8, 1, 64) | 4650 | 92 | 26 | 28 | 82 | 25 | 26 (16, 1, 1) (16, 1, 64) | 9290 | 170 | 25 | 29 | 82 | 26 | 26 (32, 1, 1) (32, 1, 64) | 18600 | 326 | 25 | 29 | 84 | 26 | 26 (64, 1, 1) (64, 1, 64) | 37160 | 634 | 26 | 32 | 83 | 26 | 26 (128, 1, 1) (128, 1, 64) | 74000 | 1260 | 26 | 31 | 82 | 26 | 26 (512, 1, 1) (512, 1, 64) | 297300 | 4960 | 26 | 42 | 82 | 26 | 26 (1024, 1, 1) (1024, 1, 64) | 595400 | 9900 | 43 | 57 | 81 | 43 | 43 (1, 1, 1) (1, 1, 256) | 583 | 24 | 26 | 28 | 75 | 35 | 26 (2, 1, 1) (2, 1, 256) | 1170 | 34 | 25 | 28 | 82 | 27 | 26 (4, 1, 1) (4, 1, 256) | 2328 | 53 | 25 | 28 | 84 | 28 | 26 (8, 1, 1) (8, 1, 256) | 4660 | 93 | 25 | 29 | 83 | 29 | 26 (16, 1, 1) (16, 1, 256) | 9300 | 170 | 25 | 30 | 82 | 30 | 26 (32, 1, 1) (32, 1, 256) | 19000 | 326 | 25 | 35 | 82 | 32 | 26 (64, 1, 1) (64, 1, 256) | 37200 | 638 | 26 | 37 | 82 | 35 | 26 (128, 1, 1) (128, 1, 256) | 74000 | 1300 | 26 | 43 | 82 | 40 | 26 (512, 1, 1) (512, 1, 256) | 297900 | 5000 | 78 | 94 | 99 | 94 | 77 (1024, 1, 1) (1024, 1, 256) | 595100 | 9900 | 146 | 200 | 
172 | 164 | 145 (1, 2, 2) (1, 2, 1) | 590 | 31 | 26 | 41 | 76 | 26 | 26 (2, 2, 2) (2, 2, 1) | 1190 | 47 | 26 | 44 | 80 | 26 | 26 (4, 2, 2) (4, 2, 1) | 2357 | 79 | 26 | 41 | 83 | 26 | 26 (8, 2, 2) (8, 2, 1) | 4700 | 140 | 26 | 41 | 82 | 26 | 26 (16, 2, 2) (16, 2, 1) | 9400 | 268 | 26 | 42 | 83 | 26 | 26 (32, 2, 2) (32, 2, 1) | 19000 | 518 | 26 | 44 | 83 | 26 | 26 (64, 2, 2) (64, 2, 1) | 37710 | 1020 | 26 | 42 | 90 | 26 | 26 (128, 2, 2) (128, 2, 1) | 75000 | 2021 | 26 | 42 | 90 | 26 | 26 (512, 2, 2) (512, 2, 1) | 299500 | 8000 | 26 | 50 | 85 | 26 | 26 (1024, 2, 2) (1024, 2, 1) | 599500 | 16030 | 26 | 67 | 99 | 26 | 26 (1, 2, 2) (1, 2, 8) | 600 | 30 | 25 | 26 | 83 | 25 | 25 (2, 2, 2) (2, 2, 8) | 1200 | 46 | 25 | 26 | 85 | 25 | 25 (4, 2, 2) (4, 2, 8) | 2372 | 77 | 25 | 30 | 89 | 27 | 25 (8, 2, 2) (8, 2, 8) | 4740 | 138 | 25 | 28 | 90 | 25 | 25 (16, 2, 2) (16, 2, 8) | 9500 | 274 | 25 | 27 | 90 | 25 | 25 (32, 2, 2) (32, 2, 8) | 19000 | 544 | 25 | 29 | 90 | 25 | 25 (64, 2, 2) (64, 2, 8) | 37680 | 1080 | 25 | 28 | 91 | 25 | 25 (128, 2, 2) (128, 2, 8) | 75500 | 2161 | 25 | 30 | 90 | 25 | 25 (512, 2, 2) (512, 2, 8) | 300700 | 8635 | 25 | 39 | 91 | 25 | 25 (1024, 2, 2) (1024, 2, 8) | 604600 | 17270 | 25 | 52 | 85 | 25 | 25 (1, 2, 2) (1, 2, 64) | 577 | 23 | 25 | 27 | 79 | 25 | 25 (2, 2, 2) (2, 2, 64) | 1157 | 33 | 25 | 28 | 90 | 25 | 25 (4, 2, 2) (4, 2, 64) | 2316 | 53 | 25 | 29 | 90 | 25 | 25 (8, 2, 2) (8, 2, 64) | 4700 | 91 | 25 | 30 | 90 | 25 | 25 (16, 2, 2) (16, 2, 64) | 9290 | 168 | 25 | 30 | 91 | 25 | 25 (32, 2, 2) (32, 2, 64) | 19000 | 322 | 25 | 31 | 90 | 25 | 25 (64, 2, 2) (64, 2, 64) | 37440 | 630 | 25 | 31 | 90 | 25 | 25 (128, 2, 2) (128, 2, 64) | 75000 | 1200 | 26 | 36 | 90 | 25 | 25 (512, 2, 2) (512, 2, 64) | 300000 | 4930 | 33 | 50 | 84 | 33 | 33 (1024, 2, 2) (1024, 2, 64) | 599200 | 9800 | 56 | 70 | 83 | 55 | 56 (1, 2, 2) (1, 2, 256) | 586 | 23 | 25 | 30 | 80 | 28 | 25 (2, 2, 2) (2, 2, 256) | 1170 | 33 | 25 | 31 | 84 | 29 | 25 (4, 2, 2) (4, 2, 256) | 2347 | 53 | 
25 | 32 | 83 | 29 | 25 (8, 2, 2) (8, 2, 256) | 4690 | 92 | 25 | 30 | 84 | 30 | 25 (16, 2, 2) (16, 2, 256) | 9400 | 170 | 26 | 31 | 90 | 31 | 25 (32, 2, 2) (32, 2, 256) | 19000 | 322 | 25 | 35 | 83 | 34 | 25 (64, 2, 2) (64, 2, 256) | 37540 | 628 | 25 | 38 | 83 | 36 | 25 (128, 2, 2) (128, 2, 256) | 75000 | 1260 | 37 | 44 | 83 | 43 | 37 (512, 2, 2) (512, 2, 256) | 299500 | 4960 | 100 | 110 | 133 | 103 | 103 (1024, 2, 2) (1024, 2, 256) | 585400 | 10000 | 211 | 210 | 265 | 195 | 209 (1, 8, 8) (1, 8, 1) | 577 | 33 | 26 | 44 | 76 | 26 | 26 (2, 8, 8) (2, 8, 1) | 1200 | 48 | 26 | 45 | 90 | 26 | 26 (4, 8, 8) (4, 8, 1) | 2301 | 80 | 26 | 40 | 84 | 26 | 26 (8, 8, 8) (8, 8, 1) | 4600 | 142 | 26 | 50 | 83 | 26 | 26 (16, 8, 8) (16, 8, 1) | 9100 | 267 | 26 | 50 | 82 | 26 | 26 (32, 8, 8) (32, 8, 1) | 19000 | 519 | 26 | 50 | 83 | 26 | 26 (64, 8, 8) (64, 8, 1) | 36980 | 1020 | 26 | 46 | 90 | 26 | 26 (128, 8, 8) (128, 8, 1) | 74000 | 2053 | 26 | 50 | 85 | 26 | 26 (512, 8, 8) (512, 8, 1) | 295100 | 8100 | 26 | 64 | 82 | 26 | 26 (1024, 8, 8) (1024, 8, 1) | 593200 | 16100 | 26 | 90 | 82 | 26 | 26 (1, 8, 8) (1, 8, 8) | 590 | 32 | 25 | 28 | 77 | 25 | 25 (2, 8, 8) (2, 8, 8) | 1200 | 47 | 25 | 28 | 83 | 25 | 25 (4, 8, 8) (4, 8, 8) | 2322 | 80 | 25 | 29 | 83 | 25 | 25 (8, 8, 8) (8, 8, 8) | 4640 | 145 | 25 | 28 | 83 | 25 | 26 (16, 8, 8) (16, 8, 8) | 9300 | 286 | 25 | 29 | 84 | 26 | 25 (32, 8, 8) (32, 8, 8) | 18500 | 568 | 26 | 32 | 85 | 26 | 25 (64, 8, 8) (64, 8, 8) | 37020 | 1200 | 26 | 30 | 83 | 25 | 25 (128, 8, 8) (128, 8, 8) | 74000 | 2261 | 26 | 31 | 83 | 26 | 25 (512, 8, 8) (512, 8, 8) | 294300 | 9030 | 25 | 41 | 83 | 25 | 25 (1024, 8, 8) (1024, 8, 8) | 592700 | 20000 | 26 | 60 | 82 | 26 | 26 (1, 8, 8) (1, 8, 64) | 564 | 24 | 25 | 27 | 76 | 25 | 25 (2, 8, 8) (2, 8, 64) | 1100 | 34 | 25 | 28 | 82 | 25 | 25 (4, 8, 8) (4, 8, 64) | 2289 | 53 | 25 | 29 | 84 | 25 | 25 (8, 8, 8) (8, 8, 64) | 4500 | 92 | 25 | 30 | 84 | 25 | 25 (16, 8, 8) (16, 8, 64) | 9200 | 170 | 25 | 30 | 83 | 25 | 25 (32, 8, 
8) (32, 8, 64) | 18000 | 322 | 25 | 33 | 83 | 25 | 25 (64, 8, 8) (64, 8, 64) | 36870 | 629 | 28 | 32 | 82 | 27 | 28 (128, 8, 8) (128, 8, 64) | 73000 | 1250 | 36 | 38 | 82 | 36 | 35 (512, 8, 8) (512, 8, 64) | 292500 | 5000 | 79 | 88 | 104 | 78 | 79 (1024, 8, 8) (1024, 8, 64) | 582300 | 9900 | 154 | 159 | 196 | 154 | 155 (1, 8, 8) (1, 8, 256) | 582 | 25 | 25 | 28 | 76 | 28 | 25 (2, 8, 8) (2, 8, 256) | 1160 | 35 | 25 | 30 | 83 | 30 | 26 (4, 8, 8) (4, 8, 256) | 2304 | 54 | 31 | 31 | 83 | 30 | 26 (8, 8, 8) (8, 8, 256) | 4600 | 92 | 38 | 34 | 82 | 33 | 41 (16, 8, 8) (16, 8, 256) | 9200 | 169 | 64 | 37 | 82 | 37 | 65 (32, 8, 8) (32, 8, 256) | 18000 | 324 | 74 | 44 | 82 | 42 | 73 (64, 8, 8) (64, 8, 256) | 37090 | 632 | 88 | 61 | 83 | 58 | 88 (128, 8, 8) (128, 8, 256) | 74000 | 1250 | 116 | 91 | 101 | 90 | 115 (512, 8, 8) (512, 8, 256) | 304900 | 5000 | 382 | 314 | 356 | 308 | 383 (1024, 8, 8) (1024, 8, 256) | 609500 | 9800 | 774 | 615 | 672 | 590 | 775 (1, 16, 16) (1, 16, 1) | 596 | 33 | 26 | 46 | 76 | 31 | 25 (2, 16, 16) (2, 16, 1) | 1200 | 48 | 26 | 47 | 83 | 47 | 26 (4, 16, 16) (4, 16, 1) | 2371 | 79 | 26 | 49 | 83 | 26 | 26 (8, 16, 16) (8, 16, 1) | 4726 | 143 | 26 | 50 | 84 | 26 | 26 (16, 16, 16) (16, 16, 1) | 9400 | 270 | 26 | 49 | 83 | 26 | 26 (32, 16, 16) (32, 16, 1) | 18800 | 520 | 26 | 50 | 83 | 26 | 26 (64, 16, 16) (64, 16, 1) | 37280 | 1020 | 26 | 48 | 84 | 26 | 25 (128, 16, 16) (128, 16, 1) | 74000 | 2018 | 26 | 50 | 83 | 26 | 26 (512, 16, 16) (512, 16, 1) | 296300 | 8000 | 26 | 77 | 82 | 26 | 25 (1024, 16, 16) (1024, 16, 1) | 593800 | 16000 | 28 | 110 | 82 | 27 | 28 (1, 16, 16) (1, 16, 8) | 590 | 30 | 25 | 27 | 76 | 30 | 25 (2, 16, 16) (2, 16, 8) | 1200 | 46 | 25 | 29 | 82 | 46 | 25 (4, 16, 16) (4, 16, 8) | 2381 | 82 | 25 | 30 | 82 | 26 | 25 (8, 16, 16) (8, 16, 8) | 4700 | 200 | 25 | 30 | 84 | 26 | 25 (16, 16, 16) (16, 16, 8) | 9400 | 320 | 25 | 30 | 83 | 26 | 25 (32, 16, 16) (32, 16, 8) | 19000 | 640 | 25 | 31 | 83 | 26 | 25 (64, 16, 16) (64, 16, 8) | 37260 | 
1234 | 25 | 31 | 91 | 26 | 25
(128, 16, 16) (128, 16, 8) | 74000 | 2542 | 25 | 33 | 90 | 26 | 25
(512, 16, 16) (512, 16, 8) | 299800 | 10000 | 36 | 45 | 82 | 36 | 36
(1024, 16, 16) (1024, 16, 8) | 602300 | 20000 | 56 | 63 | 93 | 56 | 56
(1, 16, 16) (1, 16, 64) | 582 | 24 | 25 | 29 | 76 | 23 | 25
(2, 16, 16) (2, 16, 64) | 1200 | 34 | 25 | 30 | 82 | 32 | 25
(4, 16, 16) (4, 16, 64) | 2318 | 53 | 25 | 30 | 83 | 83 | 25
(8, 16, 16) (8, 16, 64) | 4680 | 92 | 28 | 31 | 83 | 84 | 27
(16, 16, 16) (16, 16, 64) | 9290 | 181 | 42 | 32 | 83 | 84 | 40
(32, 16, 16) (32, 16, 64) | 19000 | 357 | 69 | 35 | 89 | 84 | 70
(64, 16, 16) (64, 16, 64) | 37160 | 714 | 78 | 38 | 82 | 84 | 77
(128, 16, 16) (128, 16, 64) | 74000 | 1420 | 94 | 54 | 81 | 82 | 92
(512, 16, 16) (512, 16, 64) | 299100 | 5836 | 222 | 145 | 235 | 233 | 222
(1024, 16, 16) (1024, 16, 64) | 600700 | 11730 | 572 | 263 | 411 | 410 | 566
(1, 16, 16) (1, 16, 256) | 594 | 23 | 43 | 33 | 76 | 23 | 40
(2, 16, 16) (2, 16, 256) | 1200 | 34 | 48 | 34 | 88 | 33 | 49
(4, 16, 16) (4, 16, 256) | 2378 | 62 | 94 | 35 | 82 | 84 | 91
(8, 16, 16) (8, 16, 256) | 4730 | 118 | 127 | 38 | 82 | 84 | 120
(16, 16, 16) (16, 16, 256) | 9500 | 244 | 142 | 45 | 83 | 84 | 141
(32, 16, 16) (32, 16, 256) | 19000 | 477 | 240 | 59 | 81 | 82 | 239
(64, 16, 16) (64, 16, 256) | 37610 | 969 | 279 | 90 | 118 | 118 | 277
(128, 16, 16) (128, 16, 256) | 76000 | 1960 | 340 | 156 | 218 | 217 | 343
(512, 16, 16) (512, 16, 256) | 302200 | 7910 | 1151 | 542 | 765 | 770 | 1145
(1024, 16, 16) (1024, 16, 256) | 603400 | 15810 | 2404 | 1040 | 1500 | 1490 | 2299
(1, 32, 32) (1, 32, 1) | 591 | 31 | 28 | 53 | 77 | 30 | 28
(2, 32, 32) (2, 32, 1) | 1200 | 46 | 28 | 54 | 83 | 46 | 28
(4, 32, 32) (4, 32, 1) | 2312 | 78 | 31 | 55 | 83 | 31 | 31
(8, 32, 32) (8, 32, 1) | 4600 | 144 | 31 | 56 | 83 | 31 | 31
(16, 32, 32) (16, 32, 1) | 9300 | 287 | 32 | 56 | 82 | 32 | 32
(32, 32, 32) (32, 32, 1) | 18700 | 572 | 35 | 56 | 83 | 35 | 35
(64, 32, 32) (64, 32, 1) | 37620 | 1140 | 36 | 55 | 83 | 36 | 36
(128, 32, 32) (128, 32, 1) | 75000 | 2275 | 43 | 59 | 82 | 43 | 43
(512, 32, 32) (512, 32, 1) | 300700 | 9080 | 83 | 100 | 131 | 83 | 83
(1024, 32, 32) (1024, 32, 1) | 602000 | 19500 | 123 | 190 | 175 | 123 | 123
(1, 32, 32) (1, 32, 8) | 600 | 30 | 25 | 252 | 76 | 29 | 25
(2, 32, 32) (2, 32, 8) | 1200 | 49 | 25 | 253 | 83 | 49 | 25
(4, 32, 32) (4, 32, 8) | 2388 | 94 | 30 | 254 | 83 | 30 | 31
(8, 32, 32) (8, 32, 8) | 4800 | 183 | 33 | 257 | 83 | 32 | 32
(16, 32, 32) (16, 32, 8) | 9600 | 365 | 34 | 258 | 83 | 35 | 35
(32, 32, 32) (32, 32, 8) | 19100 | 727 | 45 | 260 | 83 | 44 | 44
(64, 32, 32) (64, 32, 8) | 38230 | 1453 | 45 | 260 | 83 | 45 | 45
(128, 32, 32) (128, 32, 8) | 76400 | 2901 | 51 | 262 | 83 | 50 | 51
(512, 32, 32) (512, 32, 8) | 306900 | 12000 | 96 | 303 | 142 | 96 | 95
(1024, 32, 32) (1024, 32, 8) | 610900 | 24520 | 162 | 348 | 191 | 162 | 161
(1, 32, 32) (1, 32, 64) | 586 | 23 | 37 | 264 | 76 | 23 | 38
(2, 32, 32) (2, 32, 64) | 1200 | 39 | 44 | 268 | 85 | 41 | 44
(4, 32, 32) (4, 32, 64) | 2367 | 75 | 76 | 270 | 83 | 83 | 77
(8, 32, 32) (8, 32, 64) | 4750 | 145 | 86 | 273 | 83 | 83 | 84
(16, 32, 32) (16, 32, 64) | 9500 | 298 | 124 | 275 | 83 | 83 | 125
(32, 32, 32) (32, 32, 64) | 18900 | 594 | 239 | 272 | 82 | 82 | 238
(64, 32, 32) (64, 32, 64) | 38020 | 1157 | 265 | 276 | 110 | 110 | 262
(128, 32, 32) (128, 32, 64) | 76000 | 2320 | 320 | 293 | 179 | 180 | 306
(512, 32, 32) (512, 32, 64) | 304300 | 9611 | 840 | 533 | 613 | 614 | 835
(1024, 32, 32) (1024, 32, 64) | 605700 | 19190 | 2000 | 910 | 1130 | 1135 | 1900
(1, 32, 32) (1, 32, 256) | 610 | 34 | 87 | 257 | 77 | 31 | 84
(2, 32, 32) (2, 32, 256) | 1300 | 67 | 200 | 265 | 83 | 63 | 179
(4, 32, 32) (4, 32, 256) | 2441 | 126 | 270 | 264 | 84 | 84 | 247
(8, 32, 32) (8, 32, 256) | 4900 | 238 | 274 | 270 | 83 | 83 | 278
(16, 32, 32) (16, 32, 256) | 9600 | 478 | 450 | 281 | 107 | 107 | 433
(32, 32, 32) (32, 32, 256) | 20000 | 972 | 960 | 295 | 172 | 174 | 903
(64, 32, 32) (64, 32, 256) | 38910 | 1915 | 1000 | 352 | 308 | 309 | 993
(128, 32, 32) (128, 32, 256) | 80000 | 3833 | 1200 | 620 | 572 | 572 | 1150
(512, 32, 32) (512, 32, 256) | 313000 | 15410 | 3723 | 2000 | 2173 | 2161 | 3621
(1024, 32, 32) (1024, 32, 256) | 625200 | 30880 | 8000 | 3277 | 4276 | 4279 | 7700
(1, 64, 64) (1, 64, 1) | 609 | 30 | 55 | 68 | 66 | 30 | 30
(2, 64, 64) (2, 64, 1) | 1201 | 53 | 61 | 70 | 85 | 53 | 53
(4, 64, 64) (4, 64, 1) | 2393 | 102 | 70 | 70 | 110 | 65 | 65
(8, 64, 64) (8, 64, 1) | 4824 | 200 | 68 | 71 | 155 | 66 | 66
(16, 64, 64) (16, 64, 1) | 9600 | 401 | 70 | 72 | 84 | 67 | 67
(32, 64, 64) (32, 64, 1) | 19000 | 802 | 76 | 72 | 88 | 73 | 73
(64, 64, 64) (64, 64, 1) | 38890 | 1601 | 76 | 72 | 90 | 76 | 76
(128, 64, 64) (128, 64, 1) | 76000 | 3198 | 100 | 77 | 115 | 96 | 96
(512, 64, 64) (512, 64, 1) | 301700 | 13720 | 210 | 216 | 286 | 209 | 209
(1024, 64, 64) (1024, 64, 1) | 598400 | 28090 | 308 | 355 | 386 | 308 | 308
(1, 64, 64) (1, 64, 8) | 600 | 37 | 48 | 302 | 66 | 37 | 36
(2, 64, 64) (2, 64, 8) | 1190 | 69 | 52 | 310 | 86 | 69 | 69
(4, 64, 64) (4, 64, 8) | 2383 | 135 | 64 | 312 | 112 | 64 | 64
(8, 64, 64) (8, 64, 8) | 4730 | 267 | 67 | 317 | 200 | 66 | 66
(16, 64, 64) (16, 64, 8) | 9500 | 535 | 73 | 318 | 83 | 73 | 73
(32, 64, 64) (32, 64, 8) | 19000 | 1060 | 94 | 323 | 83 | 94 | 94
(64, 64, 64) (64, 64, 8) | 37620 | 2129 | 97 | 329 | 86 | 96 | 97
(128, 64, 64) (128, 64, 8) | 75000 | 4245 | 113 | 343 | 119 | 113 | 113
(512, 64, 64) (512, 64, 8) | 299900 | 18150 | 240 | 444 | 312 | 240 | 240
(1024, 64, 64) (1024, 64, 8) | 602200 | 36350 | 410 | 608 | 465 | 410 | 410
(1, 64, 64) (1, 64, 64) | 580 | 35 | 80 | 314 | 59 | 57 | 34
(2, 64, 64) (2, 64, 64) | 1160 | 64 | 86 | 322 | 71 | 70 | 64
(4, 64, 64) (4, 64, 64) | 2323 | 131 | 162 | 326 | 85 | 85 | 162
(8, 64, 64) (8, 64, 64) | 4640 | 254 | 177 | 331 | 147 | 147 | 177
(16, 64, 64) (16, 64, 64) | 9320 | 505 | 271 | 345 | 108 | 108 | 274
(32, 64, 64) (32, 64, 64) | 18700 | 980 | 532 | 361 | 151 | 151 | 533
(64, 64, 64) (64, 64, 64) | 37290 | 1998 | 595 | 374 | 238 | 238 | 593
(128, 64, 64) (128, 64, 64) | 75000 | 4124 | 717 | 465 | 414 | 415 | 718
(512, 64, 64) (512, 64, 64) | 299900 | 16430 | 2522 | 2114 | 1511 | 1500 | 2739
(1024, 64, 64) (1024, 64, 64) | 596100 | 32950 | 5900 | 4484 | 2859 | 2852 | 5600
(1, 64, 64) (1, 64, 256) | 607 | 59 | 259 | 319 | 59 | 57 | 60
(2, 64, 64) (2, 64, 256) | 1200 | 112 | 380 | 328 | 71 | 70 | 116
(4, 64, 64) (4, 64, 256) | 2429 | 225 | 535 | 332 | 106 | 106 | 535
(8, 64, 64) (8, 64, 256) | 4850 | 444 | 576 | 349 | 189 | 190 | 578
(16, 64, 64) (16, 64, 256) | 9690 | 882 | 960 | 386 | 232 | 232 | 962
(32, 64, 64) (32, 64, 256) | 19000 | 1753 | 2022 | 472 | 399 | 400 | 2024
(64, 64, 64) (64, 64, 256) | 38960 | 3546 | 2249 | 1126 | 742 | 741 | 2250
(128, 64, 64) (128, 64, 256) | 78000 | 7100 | 2730 | 2669 | 1416 | 1410 | 2734
(512, 64, 64) (512, 64, 256) | 308900 | 28150 | 12000 | 11000 | 5600 | 5568 | 12000
(1024, 64, 64) (1024, 64, 256) | 615400 | 56470 | 25350 | 22430 | 11000 | 11000 | 25100
(1, 128, 128) (1, 128, 1) | 584 | 42 | 115 | 100 | 67 | 42 | 43
(2, 128, 128) (2, 128, 1) | 1200 | 83 | 122 | 100 | 86 | 83 | 83
(4, 128, 128) (4, 128, 1) | 2312 | 162 | 136 | 100 | 113 | 137 | 136
(8, 128, 128) (8, 128, 1) | 4620 | 324 | 138 | 100 | 197 | 138 | 138
(16, 128, 128) (16, 128, 1) | 9200 | 642 | 143 | 100 | 159 | 144 | 144
(32, 128, 128) (32, 128, 1) | 18400 | 1291 | 153 | 103 | 163 | 153 | 153
(64, 128, 128) (64, 128, 1) | 36770 | 2711 | 178 | 120 | 191 | 178 | 178
(128, 128, 128) (128, 128, 1) | 74000 | 5447 | 231 | 150 | 251 | 231 | 231
(512, 128, 128) (512, 128, 1) | 292800 | 21820 | 505 | 406 | 639 | 505 | 505
(1024, 128, 128) (1024, 128, 1) | 585700 | 43630 | 769 | 685 | 911 | 769 | 769
(1, 128, 128) (1, 128, 8) | 590 | 58 | 102 | 416 | 70 | 65 | 57
(2, 128, 128) (2, 128, 8) | 1170 | 111 | 109 | 436 | 98 | 98 | 111
(4, 128, 128) (4, 128, 8) | 2329 | 221 | 133 | 446 | 170 | 168 | 133
(8, 128, 128) (8, 128, 8) | 4620 | 437 | 139 | 453 | 308 | 306 | 138
(16, 128, 128) (16, 128, 8) | 9260 | 863 | 154 | 460 | 144 | 145 | 154
(32, 128, 128) (32, 128, 8) | 19000 | 1723 | 202 | 480 | 158 | 155 | 202
(64, 128, 128) (64, 128, 8) | 38360 | 3697 | 223 | 493 | 184 | 185 | 224
(128, 128, 128) (128, 128, 8) | 100000 | 7435 | 267 | 532 | 262 | 263 | 268
(512, 128, 128) (512, 128, 8) | 291300 | 29760 | 595 | 827 | 745 | 744 | 596
(1024, 128, 128) (1024, 128, 8) | 572900 | 59560 | 1290 | 1372 | 1132 | 1133 | 1308
(1, 128, 128) (1, 128, 64) | 567 | 61 | 164 | 447 | 59 | 60 | 63
(2, 128, 128) (2, 128, 64) | 1100 | 126 | 206 | 467 | 96 | 98 | 124
(4, 128, 128) (4, 128, 64) | 2848 | 252 | 350 | 490 | 167 | 168 | 352
(8, 128, 128) (8, 128, 64) | 4550 | 482 | 377 | 498 | 307 | 309 | 376
(16, 128, 128) (16, 128, 64) | 9300 | 972 | 578 | 522 | 215 | 217 | 580
(32, 128, 128) (32, 128, 64) | 19000 | 2022 | 1120 | 561 | 312 | 313 | 1120
(64, 128, 128) (64, 128, 64) | 37780 | 4128 | 1260 | 615 | 500 | 501 | 1260
(128, 128, 128) (128, 128, 64) | 75400 | 8234 | 1553 | 1111 | 876 | 876 | 1550
(512, 128, 128) (512, 128, 64) | 303200 | 32920 | 7600 | 6050 | 3345 | 3366 | 7840
(1024, 128, 128) (1024, 128, 64) | 604500 | 65880 | 16500 | 13000 | 6410 | 6400 | 16400
(1, 128, 128) (1, 128, 256) | 648 | 117 | 632 | 460 | 74 | 74 | 116
(2, 128, 128) (2, 128, 256) | 1300 | 231 | 825 | 480 | 126 | 126 | 227
(4, 128, 128) (4, 128, 256) | 2586 | 456 | 1100 | 497 | 227 | 226 | 1100
(8, 128, 128) (8, 128, 256) | 5170 | 920 | 1190 | 542 | 426 | 426 | 1190
(16, 128, 128) (16, 128, 256) | 10400 | 1841 | 2024 | 688 | 481 | 482 | 2035
(32, 128, 128) (32, 128, 256) | 20680 | 3702 | 4243 | 1220 | 836 | 836 | 4252
(64, 128, 128) (64, 128, 256) | 41350 | 7421 | 4803 | 3260 | 1600 | 1600 | 4810
(128, 128, 128) (128, 128, 256) | 83000 | 14850 | 6470 | 7600 | 3075 | 3072 | 6470
(512, 128, 128) (512, 128, 256) | 331800 | 59180 | 30510 | 30480 | …
When `linalg_lu_solve` was added in https://github.com/pytorch/pytorch/pull/72935, I made the big mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be false by a large margin. This PR amends this and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and it was crafted by hand, but as the results show, it is effective.

xwang233 the cuSOLVER team should look into this one, as I was able to outperform both the cuBLAS and cuSOLVER algorithms by using triangular solves...

**Edit:** These new heuristics yield a **3x-4x performance improvement** over the previous ones when the **rhs is a square matrix or batch of square matrices** (which is the case when computing inverses, gradients for the determinant, etc.).

<details>
<summary> Benchmark Results (adjoint=False) </summary>

```
[------------------------------ linalg.lu_solve CUDA ------------------------------]
| lu_solve looped_magma | lu_solve looped_cusolver | lu_solve batched_cublas | lu_solve batched_magma | lu_solve unpack+solve_triangular | lu_solve heuristic | previous_heuristic
1 threads: --------------------------------------------------------------------------
(1, 1, 1) (1, 1, 1) | 591 | 31 | 25 | 50 | 74 | 27 | 26
(2, 1, 1) (2, 1, 1) | 1200 | 47 | 25 | 41 | 81 | 25 | 26
(4, 1, 1) (4, 1, 1) | 2340 | 79 | 25 | 42 | 81 | 26 | 26
(8, 1, 1) (8, 1, 1) | 4680 | 140 | 25 | 43 | 84 | 25 | 26
(16, 1, 1) (16, 1, 1) | 9400 | 268 | 25 | 45 | 82 | 25 | 26
(32, 1, 1) (32, 1, 1) | 18700 | 519 | 25 | 41 | 81 | 26 | 26
(64, 1, 1) (64, 1, 1) | 37240 | 1020 | 25 | 39 | 83 | 26 | 26
(128, 1, 1) (128, 1, 1) | 75000 | 2021 | 26 | 41 | 81 | 26 | 26
(512, 1, 1) (512, 1, 1) | 299400 | 8000 | 25 | 50 | 83 | 26 | 25
(1024, 1, 1) (1024, 1, 1) | 598600 | 16100 | 26 | 55 | 82 | 26 | 26
(1, 1, 1) (1, 1, 8) | 596 | 31 | 25 | 24 | 75 | 25 | 26
(2, 1, 1) (2, 1, 8) | 1190 | 47 | 25 | 25 | 83 | 25 | 26
(4, 1, 1) (4, 1, 8) | 2367 | 78 | 26 | 25 | 83 | 25 | 26
(8, 1, 1) (8, 1, 8) | 4800 | 140 | 25 | 25 | 82 | 25 | 25
(16, 1, 1) (16, 1, 8) | 9500 | 269 | 26 | 26 | 82 | 25 | 25
(32, 1, 1) (32, 1, 8) | 18900 | 536 | 26 | 26 | 82 | 26 | 26
(64, 1, 1) (64, 1, 8) | 37710 | 1066 | 26 | 26 | 84 | 26 | 25
(128, 1, 1) (128, 1, 8) | 75400 | 2131 | 26 | 27 | 82 | 26 | 26
(512, 1, 1) (512, 1, 8) | 301900 | 8520 | 26 | 33 | 80 | 26 | 26
(1024, 1, 1) (1024, 1, 8) | 603600 | 17030 | 26 | 41 | 82 | 26 | 26
(1, 1, 1) (1, 1, 64) | 577 | 23 | 25 | 26 | 75 | 26 | 26
(2, 1, 1) (2, 1, 64) | 1160 | 34 | 26 | 27 | 83 | 26 | 26
(4, 1, 1) (4, 1, 64) | 2334 | 53 | 25 | 27 | 84 | 25 | 26
(8, 1, 1) (8, 1, 64) | 4650 | 92 | 26 | 28 | 82 | 25 | 26
(16, 1, 1) (16, 1, 64) | 9290 | 170 | 25 | 29 | 82 | 26 | 26
(32, 1, 1) (32, 1, 64) | 18600 | 326 | 25 | 29 | 84 | 26 | 26
(64, 1, 1) (64, 1, 64) | 37160 | 634 | 26 | 32 | 83 | 26 | 26
(128, 1, 1) (128, 1, 64) | 74000 | 1260 | 26 | 31 | 82 | 26 | 26
(512, 1, 1) (512, 1, 64) | 297300 | 4960 | 26 | 42 | 82 | 26 | 26
(1024, 1, 1) (1024, 1, 64) | 595400 | 9900 | 43 | 57 | 81 | 43 | 43
(1, 1, 1) (1, 1, 256) | 583 | 24 | 26 | 28 | 75 | 35 | 26
(2, 1, 1) (2, 1, 256) | 1170 | 34 | 25 | 28 | 82 | 27 | 26
(4, 1, 1) (4, 1, 256) | 2328 | 53 | 25 | 28 | 84 | 28 | 26
(8, 1, 1) (8, 1, 256) | 4660 | 93 | 25 | 29 | 83 | 29 | 26
(16, 1, 1) (16, 1, 256) | 9300 | 170 | 25 | 30 | 82 | 30 | 26
(32, 1, 1) (32, 1, 256) | 19000 | 326 | 25 | 35 | 82 | 32 | 26
(64, 1, 1) (64, 1, 256) | 37200 | 638 | 26 | 37 | 82 | 35 | 26
(128, 1, 1) (128, 1, 256) | 74000 | 1300 | 26 | 43 | 82 | 40 | 26
(512, 1, 1) (512, 1, 256) | 297900 | 5000 | 78 | 94 | 99 | 94 | 77
(1024, 1, 1) (1024, 1, 256) | 595100 | 9900 | 146 | 200 | 172 | 164 | 145
(1, 2, 2) (1, 2, 1) | 590 | 31 | 26 | 41 | 76 | 26 | 26
(2, 2, 2) (2, 2, 1) | 1190 | 47 | 26 | 44 | 80 | 26 | 26
(4, 2, 2) (4, 2, 1) | 2357 | 79 | 26 | 41 | 83 | 26 | 26
(8, 2, 2) (8, 2, 1) | 4700 | 140 | 26 | 41 | 82 | 26 | 26
(16, 2, 2) (16, 2, 1) | 9400 | 268 | 26 | 42 | 83 | 26 | 26
(32, 2, 2) (32, 2, 1) | 19000 | 518 | 26 | 44 | 83 | 26 | 26
(64, 2, 2) (64, 2, 1) | 37710 | 1020 | 26 | 42 | 90 | 26 | 26
(128, 2, 2) (128, 2, 1) | 75000 | 2021 | 26 | 42 | 90 | 26 | 26
(512, 2, 2) (512, 2, 1) | 299500 | 8000 | 26 | 50 | 85 | 26 | 26
(1024, 2, 2) (1024, 2, 1) | 599500 | 16030 | 26 | 67 | 99 | 26 | 26
(1, 2, 2) (1, 2, 8) | 600 | 30 | 25 | 26 | 83 | 25 | 25
(2, 2, 2) (2, 2, 8) | 1200 | 46 | 25 | 26 | 85 | 25 | 25
(4, 2, 2) (4, 2, 8) | 2372 | 77 | 25 | 30 | 89 | 27 | 25
(8, 2, 2) (8, 2, 8) | 4740 | 138 | 25 | 28 | 90 | 25 | 25
(16, 2, 2) (16, 2, 8) | 9500 | 274 | 25 | 27 | 90 | 25 | 25
(32, 2, 2) (32, 2, 8) | 19000 | 544 | 25 | 29 | 90 | 25 | 25
(64, 2, 2) (64, 2, 8) | 37680 | 1080 | 25 | 28 | 91 | 25 | 25
(128, 2, 2) (128, 2, 8) | 75500 | 2161 | 25 | 30 | 90 | 25 | 25
(512, 2, 2) (512, 2, 8) | 300700 | 8635 | 25 | 39 | 91 | 25 | 25
(1024, 2, 2) (1024, 2, 8) | 604600 | 17270 | 25 | 52 | 85 | 25 | 25
(1, 2, 2) (1, 2, 64) | 577 | 23 | 25 | 27 | 79 | 25 | 25
(2, 2, 2) (2, 2, 64) | 1157 | 33 | 25 | 28 | 90 | 25 | 25
(4, 2, 2) (4, 2, 64) | 2316 | 53 | 25 | 29 | 90 | 25 | 25
(8, 2, 2) (8, 2, 64) | 4700 | 91 | 25 | 30 | 90 | 25 | 25
(16, 2, 2) (16, 2, 64) | 9290 | 168 | 25 | 30 | 91 | 25 | 25
(32, 2, 2) (32, 2, 64) | 19000 | 322 | 25 | 31 | 90 | 25 | 25
(64, 2, 2) (64, 2, 64) | 37440 | 630 | 25 | 31 | 90 | 25 | 25
(128, 2, 2) (128, 2, 64) | 75000 | 1200 | 26 | 36 | 90 | 25 | 25
(512, 2, 2) (512, 2, 64) | 300000 | 4930 | 33 | 50 | 84 | 33 | 33
(1024, 2, 2) (1024, 2, 64) | 599200 | 9800 | 56 | 70 | 83 | 55 | 56
(1, 2, 2) (1, 2, 256) | 586 | 23 | 25 | 30 | 80 | 28 | 25
(2, 2, 2) (2, 2, 256) | 1170 | 33 | 25 | 31 | 84 | 29 | 25
(4, 2, 2) (4, 2, 256) | 2347 | 53 | 25 | 32 | 83 | 29 | 25
(8, 2, 2) (8, 2, 256) | 4690 | 92 | 25 | 30 | 84 | 30 | 25
(16, 2, 2) (16, 2, 256) | 9400 | 170 | 26 | 31 | 90 | 31 | 25
(32, 2, 2) (32, 2, 256) | 19000 | 322 | 25 | 35 | 83 | 34 | 25
(64, 2, 2) (64, 2, 256) | 37540 | 628 | 25 | 38 | 83 | 36 | 25
(128, 2, 2) (128, 2, 256) | 75000 | 1260 | 37 | 44 | 83 | 43 | 37
(512, 2, 2) (512, 2, 256) | 299500 | 4960 | 100 | 110 | 133 | 103 | 103
(1024, 2, 2) (1024, 2, 256) | 585400 | 10000 | 211 | 210 | 265 | 195 | 209
(1, 8, 8) (1, 8, 1) | 577 | 33 | 26 | 44 | 76 | 26 | 26
(2, 8, 8) (2, 8, 1) | 1200 | 48 | 26 | 45 | 90 | 26 | 26
(4, 8, 8) (4, 8, 1) | 2301 | 80 | 26 | 40 | 84 | 26 | 26
(8, 8, 8) (8, 8, 1) | 4600 | 142 | 26 | 50 | 83 | 26 | 26
(16, 8, 8) (16, 8, 1) | 9100 | 267 | 26 | 50 | 82 | 26 | 26
(32, 8, 8) (32, 8, 1) | 19000 | 519 | 26 | 50 | 83 | 26 | 26
(64, 8, 8) (64, 8, 1) | 36980 | 1020 | 26 | 46 | 90 | 26 | 26
(128, 8, 8) (128, 8, 1) | 74000 | 2053 | 26 | 50 | 85 | 26 | 26
(512, 8, 8) (512, 8, 1) | 295100 | 8100 | 26 | 64 | 82 | 26 | 26
(1024, 8, 8) (1024, 8, 1) | 593200 | 16100 | 26 | 90 | 82 | 26 | 26
(1, 8, 8) (1, 8, 8) | 590 | 32 | 25 | 28 | 77 | 25 | 25
(2, 8, 8) (2, 8, 8) | 1200 | 47 | 25 | 28 | 83 | 25 | 25
(4, 8, 8) (4, 8, 8) | 2322 | 80 | 25 | 29 | 83 | 25 | 25
(8, 8, 8) (8, 8, 8) | 4640 | 145 | 25 | 28 | 83 | 25 | 26
(16, 8, 8) (16, 8, 8) | 9300 | 286 | 25 | 29 | 84 | 26 | 25
(32, 8, 8) (32, 8, 8) | 18500 | 568 | 26 | 32 | 85 | 26 | 25
(64, 8, 8) (64, 8, 8) | 37020 | 1200 | 26 | 30 | 83 | 25 | 25
(128, 8, 8) (128, 8, 8) | 74000 | 2261 | 26 | 31 | 83 | 26 | 25
(512, 8, 8) (512, 8, 8) | 294300 | 9030 | 25 | 41 | 83 | 25 | 25
(1024, 8, 8) (1024, 8, 8) | 592700 | 20000 | 26 | 60 | 82 | 26 | 26
(1, 8, 8) (1, 8, 64) | 564 | 24 | 25 | 27 | 76 | 25 | 25
(2, 8, 8) (2, 8, 64) | 1100 | 34 | 25 | 28 | 82 | 25 | 25
(4, 8, 8) (4, 8, 64) | 2289 | 53 | 25 | 29 | 84 | 25 | 25
(8, 8, 8) (8, 8, 64) | 4500 | 92 | 25 | 30 | 84 | 25 | 25
(16, 8, 8) (16, 8, 64) | 9200 | 170 | 25 | 30 | 83 | 25 | 25
(32, 8, 8) (32, 8, 64) | 18000 | 322 | 25 | 33 | 83 | 25 | 25
(64, 8, 8) (64, 8, 64) | 36870 | 629 | 28 | 32 | 82 | 27 | 28
(128, 8, 8) (128, 8, 64) | 73000 | 1250 | 36 | 38 | 82 | 36 | 35
(512, 8, 8) (512, 8, 64) | 292500 | 5000 | 79 | 88 | 104 | 78 | 79
(1024, 8, 8) (1024, 8, 64) | 582300 | 9900 | 154 | 159 | 196 | 154 | 155
(1, 8, 8) (1, 8, 256) | 582 | 25 | 25 | 28 | 76 | 28 | 25
(2, 8, 8) (2, 8, 256) | 1160 | 35 | 25 | 30 | 83 | 30 | 26
(4, 8, 8) (4, 8, 256) | 2304 | 54 | 31 | 31 | 83 | 30 | 26
(8, 8, 8) (8, 8, 256) | 4600 | 92 | 38 | 34 | 82 | 33 | 41
(16, 8, 8) (16, 8, 256) | 9200 | 169 | 64 | 37 | 82 | 37 | 65
(32, 8, 8) (32, 8, 256) | 18000 | 324 | 74 | 44 | 82 | 42 | 73
(64, 8, 8) (64, 8, 256) | 37090 | 632 | 88 | 61 | 83 | 58 | 88
(128, 8, 8) (128, 8, 256) | 74000 | 1250 | 116 | 91 | 101 | 90 | 115
(512, 8, 8) (512, 8, 256) | 304900 | 5000 | 382 | 314 | 356 | 308 | 383
(1024, 8, 8) (1024, 8, 256) | 609500 | 9800 | 774 | 615 | 672 | 590 | 775
(1, 16, 16) (1, 16, 1) | 596 | 33 | 26 | 46 | 76 | 31 | 25
(2, 16, 16) (2, 16, 1) | 1200 | 48 | 26 | 47 | 83 | 47 | 26
(4, 16, 16) (4, 16, 1) | 2371 | 79 | 26 | 49 | 83 | 26 | 26
(8, 16, 16) (8, 16, 1) | 4726 | 143 | 26 | 50 | 84 | 26 | 26
(16, 16, 16) (16, 16, 1) | 9400 | 270 | 26 | 49 | 83 | 26 | 26
(32, 16, 16) (32, 16, 1) | 18800 | 520 | 26 | 50 | 83 | 26 | 26
(64, 16, 16) (64, 16, 1) | 37280 | 1020 | 26 | 48 | 84 | 26 | 25
(128, 16, 16) (128, 16, 1) | 74000 | 2018 | 26 | 50 | 83 | 26 | 26
(512, 16, 16) (512, 16, 1) | 296300 | 8000 | 26 | 77 | 82 | 26 | 25
(1024, 16, 16) (1024, 16, 1) | 593800 | 16000 | 28 | 110 | 82 | 27 | 28
(1, 16, 16) (1, 16, 8) | 590 | 30 | 25 | 27 | 76 | 30 | 25
(2, 16, 16) (2, 16, 8) | 1200 | 46 | 25 | 29 | 82 | 46 | 25
(4, 16, 16) (4, 16, 8) | 2381 | 82 | 25 | 30 | 82 | 26 | 25
(8, 16, 16) (8, 16, 8) | 4700 | 200 | 25 | 30 | 84 | 26 | 25
(16, 16, 16) (16, 16, 8) | 9400 | 320 | 25 | 30 | 83 | 26 | 25
(32, 16, 16) (32, 16, 8) | 19000 | 640 | 25 | 31 | 83 | 26 | 25
(64, 16, 16) (64, 16, 8) | 37260 | 1234 | 25 | 31 | 91 | 26 | 25
(128, 16, 16) (128, 16, 8) | 74000 | 2542 | 25 | 33 | 90 | 26 | 25
(512, 16, 16) (512, 16, 8) | 299800 | 10000 | 36 | 45 | 82 | 36 | 36
(1024, 16, 16) (1024, 16, 8) | 602300 | 20000 | 56 | 63 | 93 | 56 | 56
(1, 16, 16) (1, 16, 64) | 582 | 24 | 25 | 29 | 76 | 23 | 25
(2, 16, 16) (2, 16, 64) | 1200 | 34 | 25 | 30 | 82 | 32 | 25
(4, 16, 16) (4, 16, 64) | 2318 | 53 | 25 | 30 | 83 | 83 | 25
(8, 16, 16) (8, 16, 64) | 4680 | 92 | 28 | 31 | 83 | 84 | 27
(16, 16, 16) (16, 16, 64) | 9290 | 181 | 42 | 32 | 83 | 84 | 40
(32, 16, 16) (32, 16, 64) | 19000 | 357 | 69 | 35 | 89 | 84 | 70
(64, 16, 16) (64, 16, 64) | 37160 | 714 | 78 | 38 | 82 | 84 | 77
(128, 16, 16) (128, 16, 64) | 74000 | 1420 | 94 | 54 | 81 | 82 | 92
(512, 16, 16) (512, 16, 64) | 299100 | 5836 | 222 | 145 | 235 | 233 | 222
(1024, 16, 16) (1024, 16, 64) | 600700 | 11730 | 572 | 263 | 411 | 410 | 566
(1, 16, 16) (1, 16, 256) | 594 | 23 | 43 | 33 | 76 | 23 | 40
(2, 16, 16) (2, 16, 256) | 1200 | 34 | 48 | 34 | 88 | 33 | 49
(4, 16, 16) (4, 16, 256) | 2378 | 62 | 94 | 35 | 82 | 84 | 91
(8, 16, 16) (8, 16, 256) | 4730 | 118 | 127 | 38 | 82 | 84 | 120
(16, 16, 16) (16, 16, 256) | 9500 | 244 | 142 | 45 | 83 | 84 | 141
(32, 16, 16) (32, 16, 256) | 19000 | 477 | 240 | 59 | 81 | 82 | 239
(64, 16, 16) (64, 16, 256) | 37610 | 969 | 279 | 90 | 118 | 118 | 277
(128, 16, 16) (128, 16, 256) | 76000 | 1960 | 340 | 156 | 218 | 217 | 343
(512, 16, 16) (512, 16, 256) | 302200 | 7910 | 1151 | 542 | 765 | 770 | 1145
(1024, 16, 16) (1024, 16, 256) | 603400 | 15810 | 2404 | 1040 | 1500 | 1490 | 2299
(1, 32, 32) (1, 32, 1) | 591 | 31 | 28 | 53 | 77 | 30 | 28
(2, 32, 32) (2, 32, 1) | 1200 | 46 | 28 | 54 | 83 | 46 | 28
(4, 32, 32) (4, 32, 1) | 2312 | 78 | 31 | 55 | 83 | 31 | 31
(8, 32, 32) (8, 32, 1) | 4600 | 144 | 31 | 56 | 83 | 31 | 31
(16, 32, 32) (16, 32, 1) | 9300 | 287 | 32 | 56 | 82 | 32 | 32
(32, 32, 32) (32, 32, 1) | 18700 | 572 | 35 | 56 | 83 | 35 | 35
(64, 32, 32) (64, 32, 1) | 37620 | 1140 | 36 | 55 | 83 | 36 | 36
(128, 32, 32) (128, 32, 1) | 75000 | 2275 | 43 | 59 | 82 | 43 | 43
(512, 32, 32) (512, 32, 1) | 300700 | 9080 | 83 | 100 | 131 | 83 | 83
(1024, 32, 32) (1024, 32, 1) | 602000 | 19500 | 123 | 190 | 175 | 123 | 123
(1, 32, 32) (1, 32, 8) | 600 | 30 | 25 | 252 | 76 | 29 | 25
(2, 32, 32) (2, 32, 8) | 1200 | 49 | 25 | 253 | 83 | 49 | 25
(4, 32, 32) (4, 32, 8) | 2388 | 94 | 30 | 254 | 83 | 30 | 31
(8, 32, 32) (8, 32, 8) | 4800 | 183 | 33 | 257 | 83 | 32 | 32
(16, 32, 32) (16, 32, 8) | 9600 | 365 | 34 | 258 | 83 | 35 | 35
(32, 32, 32) (32, 32, 8) | 19100 | 727 | 45 | 260 | 83 | 44 | 44
(64, 32, 32) (64, 32, 8) | 38230 | 1453 | 45 | 260 | 83 | 45 | 45
(128, 32, 32) (128, 32, 8) | 76400 | 2901 | 51 | 262 | 83 | 50 | 51
(512, 32, 32) (512, 32, 8) | 306900 | 12000 | 96 | 303 | 142 | 96 | 95
(1024, 32, 32) (1024, 32, 8) | 610900 | 24520 | 162 | 348 | 191 | 162 | 161
(1, 32, 32) (1, 32, 64) | 586 | 23 | 37 | 264 | 76 | 23 | 38
(2, 32, 32) (2, 32, 64) | 1200 | 39 | 44 | 268 | 85 | 41 | 44
(4, 32, 32) (4, 32, 64) | 2367 | 75 | 76 | 270 | 83 | 83 | 77
(8, 32, 32) (8, 32, 64) | 4750 | 145 | 86 | 273 | 83 | 83 | 84
(16, 32, 32) (16, 32, 64) | 9500 | 298 | 124 | 275 | 83 | 83 | 125
(32, 32, 32) (32, 32, 64) | 18900 | 594 | 239 | 272 | 82 | 82 | 238
(64, 32, 32) (64, 32, 64) | 38020 | 1157 | 265 | 276 | 110 | 110 | 262
(128, 32, 32) (128, 32, 64) | 76000 | 2320 | 320 | 293 | 179 | 180 | 306
(512, 32, 32) (512, 32, 64) | 304300 | 9611 | 840 | 533 | 613 | 614 | 835
(1024, 32, 32) (1024, 32, 64) | 605700 | 19190 | 2000 | 910 | 1130 | 1135 | 1900
(1, 32, 32) (1, 32, 256) | 610 | 34 | 87 | 257 | 77 | 31 | 84
(2, 32, 32) (2, 32, 256) | 1300 | 67 | 200 | 265 | 83 | 63 | 179
(4, 32, 32) (4, 32, 256) | 2441 | 126 | 270 | 264 | 84 | 84 | 247
(8, 32, 32) (8, 32, 256) | 4900 | 238 | 274 | 270 | 83 | 83 | 278
(16, 32, 32) (16, 32, 256) | 9600 | 478 | 450 | 281 | 107 | 107 | 433
(32, 32, 32) (32, 32, 256) | 20000 | 972 | 960 | 295 | 172 | 174 | 903
(64, 32, 32) (64, 32, 256) | 38910 | 1915 | 1000 | 352 | 308 | 309 | 993
(128, 32, 32) (128, 32, 256) | 80000 | 3833 | 1200 | 620 | 572 | 572 | 1150
(512, 32, 32) (512, 32, 256) | 313000 | 15410 | 3723 | 2000 | 2173 | 2161 | 3621
(1024, 32, 32) (1024, 32, 256) | 625200 | 30880 | 8000 | 3277 | 4276 | 4279 | 7700
(1, 64, 64) (1, 64, 1) | 609 | 30 | 55 | 68 | 66 | 30 | 30
(2, 64, 64) (2, 64, 1) | 1201 | 53 | 61 | 70 | 85 | 53 | 53
(4, 64, 64) (4, 64, 1) | 2393 | 102 | 70 | 70 | 110 | 65 | 65
(8, 64, 64) (8, 64, 1) | 4824 | 200 | 68 | 71 | 155 | 66 | 66
(16, 64, 64) (16, 64, 1) | 9600 | 401 | 70 | 72 | 84 | 67 | 67
(32, 64, 64) (32, 64, 1) | 19000 | 802 | 76 | 72 | 88 | 73 | 73
(64, 64, 64) (64, 64, 1) | 38890 | 1601 | 76 | 72 | 90 | 76 | 76
(128, 64, 64) (128, 64, 1) | 76000 | 3198 | 100 | 77 | 115 | 96 | 96
(512, 64, 64) (512, 64, 1) | 301700 | 13720 | 210 | 216 | 286 | 209 | 209
(1024, 64, 64) (1024, 64, 1) | 598400 | 28090 | 308 | 355 | 386 | 308 | 308
(1, 64, 64) (1, 64, 8) | 600 | 37 | 48 | 302 | 66 | 37 | 36
(2, 64, 64) (2, 64, 8) | 1190 | 69 | 52 | 310 | 86 | 69 | 69
(4, 64, 64) (4, 64, 8) | 2383 | 135 | 64 | 312 | 112 | 64 | 64
(8, 64, 64) (8, 64, 8) | 4730 | 267 | 67 | 317 | 200 | 66 | 66
(16, 64, 64) (16, 64, 8) | 9500 | 535 | 73 | 318 | 83 | 73 | 73
(32, 64, 64) (32, 64, 8) | 19000 | 1060 | 94 | 323 | 83 | 94 | 94
(64, 64, 64) (64, 64, 8) | 37620 | 2129 | 97 | 329 | 86 | 96 | 97
(128, 64, 64) (128, 64, 8) | 75000 | 4245 | 113 | 343 | 119 | 113 | 113
(512, 64, 64) (512, 64, 8) | 299900 | 18150 | 240 | 444 | 312 | 240 | 240
(1024, 64, 64) (1024, 64, 8) | 602200 | 36350 | 410 | 608 | 465 | 410 | 410
(1, 64, 64) (1, 64, 64) | 580 | 35 | 80 | 314 | 59 | 57 | 34
(2, 64, 64) (2, 64, 64) | 1160 | 64 | 86 | 322 | 71 | 70 | 64
(4, 64, 64) (4, 64, 64) | 2323 | 131 | 162 | 326 | 85 | 85 | 162
(8, 64, 64) (8, 64, 64) | 4640 | 254 | 177 | 331 | 147 | 147 | 177
(16, 64, 64) (16, 64, 64) | 9320 | 505 | 271 | 345 | 108 | 108 | 274
(32, 64, 64) (32, 64, 64) | 18700 | 980 | 532 | 361 | 151 | 151 | 533
(64, 64, 64) (64, 64, 64) | 37290 | 1998 | 595 | 374 | 238 | 238 | 593
(128, 64, 64) (128, 64, 64) | 75000 | 4124 | 717 | 465 | 414 | 415 | 718
(512, 64, 64) (512, 64, 64) | 299900 | 16430 | 2522 | 2114 | 1511 | 1500 | 2739
(1024, 64, 64) (1024, 64, 64) | 596100 | 32950 | 5900 | 4484 | 2859 | 2852 | 5600
(1, 64, 64) (1, 64, 256) | 607 | 59 | 259 | 319 | 59 | 57 | 60
(2, 64, 64) (2, 64, 256) | 1200 | 112 | 380 | 328 | 71 | 70 | 116
(4, 64, 64) (4, 64, 256) | 2429 | 225 | 535 | 332 | 106 | 106 | 535
(8, 64, 64) (8, 64, 256) | 4850 | 444 | 576 | 349 | 189 | 190 | 578
(16, 64, 64) (16, 64, 256) | 9690 | 882 | 960 | 386 | 232 | 232 | 962
(32, 64, 64) (32, 64, 256) | 19000 | 1753 | 2022 | 472 | 399 | 400 | 2024
(64, 64, 64) (64, 64, 256) | 38960 | 3546 | 2249 | 1126 | 742 | 741 | 2250
(128, 64, 64) (128, 64, 256) | 78000 | 7100 | 2730 | 2669 | 1416 | 1410 | 2734
(512, 64, 64) (512, 64, 256) | 308900 | 28150 | 12000 | 11000 | 5600 | 5568 | 12000
(1024, 64, 64) (1024, 64, 256) | 615400 | 56470 | 25350 | 22430 | 11000 | 11000 | 25100
(1, 128, 128) (1, 128, 1) | 584 | 42 | 115 | 100 | 67 | 42 | 43
(2, 128, 128) (2, 128, 1) | 1200 | 83 | 122 | 100 | 86 | 83 | 83
(4, 128, 128) (4, 128, 1) | 2312 | 162 | 136 | 100 | 113 | 137 | 136
(8, 128, 128) (8, 128, 1) | 4620 | 324 | 138 | 100 | 197 | 138 | 138
(16, 128, 128) (16, 128, 1) | 9200 | 642 | 143 | 100 | 159 | 144 | 144
(32, 128, 128) (32, 128, 1) | 18400 | 1291 | 153 | 103 | 163 | 153 | 153
(64, 128, 128) (64, 128, 1) | 36770 | 2711 | 178 | 120 | 191 | 178 | 178
(128, 128, 128) (128, 128, 1) | 74000 | 5447 | 231 | 150 | 251 | 231 | 231
(512, 128, 128) (512, 128, 1) | 292800 | 21820 | 505 | 406 | 639 | 505 | 505
(1024, 128, 128) (1024, 128, 1) | 585700 | 43630 | 769 | 685 | 911 | 769 | 769
(1, 128, 128) (1, 128, 8) | 590 | 58 | 102 | 416 | 70 | 65 | 57
(2, 128, 128) (2, 128, 8) | 1170 | 111 | 109 | 436 | 98 | 98 | 111
(4, 128, 128) (4, 128, 8) | 2329 | 221 | 133 | 446 | 170 | 168 | 133
(8, 128, 128) (8, 128, 8) | 4620 | 437 | 139 | 453 | 308 | 306 | 138
(16, 128, 128) (16, 128, 8) | 9260 | 863 | 154 | 460 | 144 | 145 | 154
(32, 128, 128) (32, 128, 8) | 19000 | 1723 | 202 | 480 | 158 | 155 | 202
(64, 128, 128) (64, 128, 8) | 38360 | 3697 | 223 | 493 | 184 | 185 | 224
(128, 128, 128) (128, 128, 8) | 100000 | 7435 | 267 | 532 | 262 | 263 | 268
(512, 128, 128) (512, 128, 8) | 291300 | 29760 | 595 | 827 | 745 | 744 | 596
(1024, 128, 128) (1024, 128, 8) | 572900 | 59560 | 1290 | 1372 | 1132 | 1133 | 1308
(1, 128, 128) (1, 128, 64) | 567 | 61 | 164 | 447 | 59 | 60 | 63
(2, 128, 128) (2, 128, 64) | 1100 | 126 | 206 | 467 | 96 | 98 | 124
(4, 128, 128) (4, 128, 64) | 2848 | 252 | 350 | 490 | 167 | 168 | 352
(8, 128, 128) (8, 128, 64) | 4550 | 482 | 377 | 498 | 307 | 309 | 376
(16, 128, 128) (16, 128, 64) | 9300 | 972 | 578 | 522 | 215 | 217 | 580
(32, 128, 128) (32, 128, 64) | 19000 | 2022 | 1120 | 561 | 312 | 313 | 1120
(64, 128, 128) (64, 128, 64) | 37780 | 4128 | 1260 | 615 | 500 | 501 | 1260
(128, 128, 128) (128, 128, 64) | 75400 | 8234 | 1553 | 1111 | 876 | 876 | 1550
(512, 128, 128) (512, 128, 64) | 303200 | 32920 | 7600 | 6050 | 3345 | 3366 | 7840
(1024, 128, 128) (1024, 128, 64) | 604500 | 65880 | 16500 | 13000 | 6410 | 6400 | 16400
(1, 128, 128) (1, 128, 256) | 648 | 117 | 632 | 460 | 74 | 74 | 116
(2, 128, 128) (2, 128, 256) | 1300 | 231 | 825 | 480 | 126 | 126 | 227
(4, 128, 128) (4, 128, 256) | 2586 | 456 | 1100 | 497 | 227 | 226 | 1100
(8, 128, 128) (8, 128, 256) | 5170 | 920 | 1190 | 542 | 426 | 426 | 1190
(16, 128, 128) (16, 128, 256) | 10400 | 1841 | 2024 | 688 | 481 | 482 | 2035
(32, 128, 128) (32, 128, 256) | 20680 | 3702 | 4243 | 1220 | 836 | 836 | 4252
(64, 128, 128) (64, 128, 256) | 41350 | 7421 | 4803 | 3260 | 1600 | 1600 | 4810
(128, 128, 128) (128, 128, 256) | 83000 | 14850 | 6470 | 7600 | 3075 | 3072 | 6470
(512, 128, 128) (512, 128, 256) | 331800 | 59180 | 30510 | 30480 | 12000 …
```

</details>
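To compare the `lu_solve heuristic` and `previous_heuristic` columns, one can take the ratio of the two timings for a given row. The sketch below does this for four rows copied from the table above. The choice of rows and the helper name `speedup` are illustrative assumptions. The ratio varies a lot with shape, so this small sample should not be read as reproducing the headline figure.

```python
# Rows copied from the benchmark table above:
# (A shape, B shape) -> (lu_solve heuristic, previous_heuristic).
# Timings are in the benchmark's own units; only their ratio is used here.
ROWS = {
    ((1024, 1, 1), (1024, 1, 1)): (26, 26),
    ((1, 64, 64), (1, 64, 64)): (57, 34),
    ((32, 64, 64), (32, 64, 64)): (151, 533),
    ((32, 128, 128), (32, 128, 256)): (836, 4252),
}


def speedup(new: float, old: float) -> float:
    """Return old / new: values above 1 mean the new heuristic is faster."""
    return old / new


speedups = {
    shapes: round(speedup(new, old), 2) for shapes, (new, old) in ROWS.items()
}

for (a_shape, b_shape), ratio in speedups.items():
    print(f"A{a_shape} B{b_shape}: {ratio:.2f}x")
```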
…n B is a matrix"

When `linalg_lu_solve` was added in #72935, I made the big mistake of assuming that the choice of backend would not depend on the number of columns of B. This turned out to be false by a large margin. This PR amends this and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and it was crafted by hand, but as the results show, it is effective.

xwang233 the cuSOLVER team should look into this one, as I was able to outperform both the cuBLAS and cuSOLVER algorithms by using triangular solves...

The benchmarks for the heuristics are here: #79838 (comment)
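The "triangular solves" mentioned above amount to solving `A X = B` from the packed LU factors with one permutation and two triangular solves. Below is a minimal pure-Python sketch of that idea for a single real matrix, assuming partial pivoting. It is not PyTorch's implementation. The names `lu_factor`, `solve_triangular`, `lu_solve` and the `adjoint` flag are illustrative helpers defined here (for real matrices the adjoint is the transpose), and batching is left out.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def lu_factor(a: Matrix) -> Tuple[Matrix, List[int]]:
    """LU with partial pivoting: row i of L @ U equals row perm[i] of A."""
    n = len(a)
    lu = [row[:] for row in a]
    perm = list(range(n))
    for k in range(n):
        pivot = max(range(k, n), key=lambda i: abs(lu[i][k]))
        if lu[pivot][k] == 0.0:
            raise ValueError("matrix is singular")
        lu[k], lu[pivot] = lu[pivot], lu[k]
        perm[k], perm[pivot] = perm[pivot], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]  # multiplier, stored in the strict lower part
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm


def solve_triangular(lu: Matrix, b: Matrix, *, upper: bool, unit: bool,
                     trans: bool = False) -> Matrix:
    """Solve T X = B where T is a triangle of `lu` (transposed if `trans`)."""
    n, ncols = len(lu), len(b[0])

    def t(i: int, j: int) -> float:
        return lu[j][i] if trans else lu[i][j]

    x = [row[:] for row in b]
    rows = range(n - 1, -1, -1) if upper else range(n)
    for i in rows:
        others = range(i + 1, n) if upper else range(i)
        for c in range(ncols):  # every column of B reuses the same factors
            s = x[i][c] - sum(t(i, j) * x[j][c] for j in others)
            x[i][c] = s if unit else s / t(i, i)
    return x


def lu_solve(lu: Matrix, perm: List[int], b: Matrix, *,
             adjoint: bool = False) -> Matrix:
    """Solve A X = B (or A^T X = B if `adjoint`) from the packed LU factors."""
    if not adjoint:
        pb = [b[p][:] for p in perm]                       # P B
        y = solve_triangular(lu, pb, upper=False, unit=True)   # L Y = P B
        return solve_triangular(lu, y, upper=True, unit=False)  # U X = Y
    # A^T = U^T L^T P, so solve U^T Y = B, then L^T Z = Y, then undo P.
    y = solve_triangular(lu, b, upper=False, unit=False, trans=True)
    z = solve_triangular(lu, y, upper=True, unit=True, trans=True)
    x: Matrix = [[] for _ in perm]
    for i, p in enumerate(perm):
        x[p] = z[i]
    return x


def matmul(p: Matrix, q: Matrix) -> Matrix:
    return [[sum(p[i][k] * q[k][j] for k in range(len(q)))
             for j in range(len(q[0]))] for i in range(len(p))]


a = [[0.0, 2.0, 1.0], [1.0, 1.0, 0.0], [2.0, 0.0, 3.0]]
b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # two right-hand sides (columns of B)
lu, perm = lu_factor(a)
x = lu_solve(lu, perm, b)
x_adj = lu_solve(lu, perm, b, adjoint=True)
```

Each extra column of B adds work to both triangular solves while the factorization is reused, which is one reason the number of columns of B can matter when choosing between backends.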
When linalg_lu_solve was added in #72935 I made the big mistake of assuming that the choice of backend would not depend on number of columns of B. This turned out to be false by a large margin. This PR amends this and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and it was crafted by hand, but as the results show, it is effective. xwang233 the cusolver team should look into this one, as I was able to outperform both cublas and cusolvers algorithms by using triangular solves... The benchmarks for the heuristics are here: #79838 (comment) [ghstack-poisoned]
…n B is a matrix" When linalg_lu_solve was added in #72935 I made the big mistake of assuming that the choice of backend would not depend on number of columns of B. This turned out to be false by a large margin. This PR amends this and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and it was crafted by hand, but as the results show, it is effective. xwang233 the cusolver team should look into this one, as I was able to outperform both cublas and cusolvers algorithms by using triangular solves... The benchmarks for the heuristics are here: #79838 (comment) [ghstack-poisoned]
When linalg_lu_solve was added in #72935 I made the big mistake of assuming that the choice of backend would not depend on number of columns of B. This turned out to be false by a large margin. This PR amends this and provides a heuristic that takes the number of columns of B into account. The heuristic is not simple and it was crafted by hand, but as the results show, it is effective. xwang233 the cusolver team should look into this one, as I was able to outperform both cublas and cusolvers algorithms by using triangular solves... The benchmarks for the heuristics are here: #79838 (comment) [ghstack-poisoned]
Pull Request resolved: #79838
Approved by: https://github.com/IvanYashchuk, https://github.com/albanD
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/6fb6dc42ff9b96776e020567a8e4e8d2aa007307
Reviewed By: mehtanirav
Differential Revision: D37604726
Pulled By: mehtanirav
fbshipit-source-id: 66445588258f3edddd3307568a2defac4c7cf896
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA when calling the batched MAGMA backend with trans=True. We work around
that by solving the system via two triangular solves.
We also update the heuristics for this function, as they were fairly
outdated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.
We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.
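The two-triangular-solve workaround mentioned above can be sketched as follows. This is a minimal pure-Python illustration of the math, not the ATen/CUDA code from this PR: it handles a single real right-hand side without batching, and the helper names `lu_factor_ref`, `solve_triangular_ref` and `lu_solve_ref` are made up for this example. With `P A = L U`, the plain solve is `L y = P b` followed by `U x = y`; the adjoint solve uses `A^T = U^T L^T P`, so it solves `U^T z = b`, then `L^T w = z`, and finally undoes the permutation.

```python
def lu_factor_ref(a):
    """LU with partial pivoting. Returns (lu, perm) such that A[perm] = L @ U,
    with L (unit diagonal) and U packed into a single matrix `lu`."""
    n = len(a)
    lu = [row[:] for row in a]
    perm = list(range(n))
    for k in range(n):
        # Pick the row with the largest entry in column k as the pivot row.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        lu[k], lu[p] = lu[p], lu[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, perm


def solve_triangular_ref(t, b, lower, unit_diagonal=False, transpose=False):
    """Solve T x = b (or T^T x = b), where T is the lower/upper triangle of t."""
    n = len(b)
    get = (lambda i, j: t[j][i]) if transpose else (lambda i, j: t[i][j])
    # Transposing a lower-triangular matrix gives an upper-triangular one.
    effective_lower = lower != transpose
    order = range(n) if effective_lower else range(n - 1, -1, -1)
    x = [0.0] * n
    for i in order:
        cols = range(i) if effective_lower else range(i + 1, n)
        s = b[i] - sum(get(i, j) * x[j] for j in cols)
        x[i] = s if unit_diagonal else s / get(i, i)
    return x


def lu_solve_ref(lu, perm, b, adjoint=False):
    """Solve A x = b (or A^T x = b) using only two triangular solves."""
    if not adjoint:
        pb = [b[p] for p in perm]  # apply the row permutation P to b
        y = solve_triangular_ref(lu, pb, lower=True, unit_diagonal=True)
        return solve_triangular_ref(lu, y, lower=False)
    z = solve_triangular_ref(lu, b, lower=False, transpose=True)
    w = solve_triangular_ref(lu, z, lower=True, unit_diagonal=True, transpose=True)
    x = [0.0] * len(b)
    for k, p in enumerate(perm):  # undo the permutation: x[perm[k]] = w[k]
        x[p] = w[k]
    return x


A = [[2.0, 1.0], [4.0, 3.0]]
lu, perm = lu_factor_ref(A)
x = lu_solve_ref(lu, perm, [4.0, 10.0])                   # solves A x = b
xt = lu_solve_ref(lu, perm, [10.0, 7.0], adjoint=True)    # solves A^T x = b
```

In the complex case the adjoint path would also conjugate the factors; that detail is left out here to keep the sketch short.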
Benchmarking
Benchmark Results (adjoint=False)
Benchmark Results (adjoint=True)
To generate the results below, I put the backend I wanted to test at the beginning of the function `lu_solve_kernel`, followed by a `return;`. Then I ran the following script, changing the variable `name`.
Benchmarking script
Finally, I joined all the results with the following script:
Script to join the results
Fix for MAGMA's batched lu_solve when `adjoint=True`
I also developed the following workaround for MAGMA's bug, but I ended up not using it, preferring the triangular solves instead, as they were faster. I'm leaving it here in case it's useful in the future.
Fix for MAGMA's issue with `adjoint=True`
Fixes #61657