Sorry, I’m on my phone at the moment, so I can’t really type decent code right now!
I actually think that torch’s ShardedTensor looks very promising. Essentially, you can either initialize a sharded tensor from an already materialized tensor, or create it on a meta device so that nothing is allocated locally and each shard gets initialized directly on its specified remote device (useful for extremely large tensors).
The sharding is described by a ShardingSpec: you can either split the tensor into equally sized shards along a single dimension across the requested devices, or do grid sharding along multiple dimensions. There’s also a more general sharding spec that lets you choose explicitly which index ranges go on which devices, if you need non-uniform shards (rough sketch below).
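To make that a bit more concrete, here’s roughly what the prototype API looks like. Caveat: I’m going from memory, so the exact module paths, class names, and placement-string format are assumptions on my part and may well change.

```python
# Rough sketch of the prototype API, meant to run on every rank of an
# initialized process group (e.g. launched with torchrun, one GPU per rank).
# Module paths, names, and the "rank:N/cuda:N" placement strings are
# assumptions from the prototype and may differ in newer versions.
import torch.distributed as dist
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)

dist.init_process_group("nccl")  # assumes 2 ranks, one GPU each

# Chunk sharding: equal-sized shards split along a single dimension.
chunk_spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
# Each shard is allocated directly on its target device; no full copy
# of the 10_000 x 1_024 tensor ever lives on a single rank.
big_weight = sharded_tensor.empty(chunk_spec, 10_000, 1_024)

# Enumerable sharding: explicitly say which index ranges live where,
# which is what you'd reach for with non-uniform or grid-style layouts.
grid_spec = EnumerableShardingSpec(shards=[
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[10_000, 512],
                  placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[0, 512], shard_sizes=[10_000, 512],
                  placement="rank:1/cuda:1"),
])
```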
I think once these are implemented (along with some special cases like cloned tensors and the like), and once the distributed autograd engine has full support for CUDA, it should be pretty easy to start building out distributed versions of common neural net operations.
The one open question (that I haven’t thought about a ton, to be frank, and I’m sure smarter people have :)) is that you’ll end up with a sharding spec for the weights as well as one for the inputs, and it’s not obvious how best to make sure the two line up (toy illustration below). Do you handle that with custom logic for each operation? Or does each operation just reshard the input automatically? The latter seems like a potentially big performance pitfall.
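As a toy, single-process illustration of the mismatch (plain tensor chunks standing in for shards, no actual ShardedTensor involved): if a linear layer’s weight is sharded along its output dimension, a replicated input is fine, but if it’s sharded along the input dimension, the input has to be sharded the same way and the partial results need a reduction. So either the op knows how to reshard its input, or the caller has to.

```python
# Toy single-process illustration: "shards" are just chunks of ordinary
# tensors on one device. The point is that the weight's sharding dimension
# dictates what the input's sharding has to look like.
import torch

batch, in_features, out_features, world_size = 4, 8, 6, 2
x = torch.randn(batch, in_features)
w = torch.randn(out_features, in_features)

# Case 1: weight sharded along dim 0 (output features), input replicated.
# Each "rank" computes a slice of the output; concatenating recovers it.
w_row_shards = torch.chunk(w, world_size, dim=0)
out_case1 = torch.cat([x @ ws.t() for ws in w_row_shards], dim=-1)

# Case 2: weight sharded along dim 1 (input features). Now the input has
# to be sharded along its feature dimension too, and the partial outputs
# need a sum (an all-reduce in the real distributed version).
w_col_shards = torch.chunk(w, world_size, dim=1)
x_shards = torch.chunk(x, world_size, dim=-1)
out_case2 = sum(xs @ ws.t() for xs, ws in zip(x_shards, w_col_shards))

reference = x @ w.t()
assert torch.allclose(out_case1, reference, atol=1e-6)
assert torch.allclose(out_case2, reference, atol=1e-6)
```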