Closed
Labels
DocArray v2: This issue is part of the rewrite; not to be merged into main
Description
When an operation is performed on a parametrized tensor, the shape declared in its type hint does not change, even when the operation changes the actual shape:
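import torch
from docarray import BaseDocument
from docarray.typing import TorchTensor

class Doc(BaseDocument):
    tensor: TorchTensor[3, 1]

d = Doc(tensor=torch.rand(size=(3, 1)))
d_t = d.tensor.transpose(0, 1)
print(d_t)        # the class still reports shape [3, 1]
print(d_t.shape)  # the actual shape is [1, 3]

Output:
TorchTensor[3, 1]([[0.7769, 0.8053, 0.0161]])
torch.Size([1, 3])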
I see two options to tackle this:
- Override __torch_function__ to correctly assign the current shape to the tensor class after every operation. In the example above, the first print would then produce TorchTensor[1, 3]([[0.7769, 0.8053, 0.0161]]). Problem with this: the transformation has to happen at every torch operation, which seems error-prone.
- Override __torch_function__ to return a torch.Tensor instead of a TorchTensor. The challenge with this is that, in the current state, it would make our type system useless outside of Document. For example, the following would not type check:
def my_helper(t: TorchTensor[512]):
    ...

class Doc(BaseDocument):
    tensor: TorchTensor[512]

d = Doc(tensor=torch.rand(512))
t = d.tensor + d.tensor  # t is now torch.Tensor
my_helper(t)  # but this wants TorchTensor[512]

To make this make sense again, we should take a look at TorchTyping and see how they achieve proper typing despite the data being torch.Tensor.
Big advantage of this: As soon as any operations occur on a TorchTensor, it turns into a torch.Tensor, meaning there is no opportunity for us to f*ck things up inside of a model.
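A minimal sketch of what option 2 could look like, using a bare torch.Tensor subclass rather than the real TorchTensor. The class name PlainResultTensor is hypothetical, and the sketch assumes PyTorch 2.x, where torch._C.DisableTorchFunctionSubclass is available. It is an illustration only, not DocArray's actual implementation.

```python
import torch


class PlainResultTensor(torch.Tensor):
    """Hypothetical stand-in for TorchTensor (not DocArray code)."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run the op with subclass dispatch disabled and skip the default
        # re-wrapping step, so newly created results are plain torch.Tensor.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)


# Build an instance of the subclass from an ordinary tensor.
t = torch.rand(3, 1).as_subclass(PlainResultTensor)

d_t = t.transpose(0, 1)  # view-producing op
s = t + t                # arithmetic op

print(type(t).__name__)                             # the input keeps its subclass
print(type(d_t) is torch.Tensor, tuple(d_t.shape))  # result type and actual shape
print(type(s) is torch.Tensor)
```

Note that in-place operations such as t.add_(1) return the same object, so in this sketch they would keep the subclass; only newly created results come back as torch.Tensor.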
Preliminary conclusion: Let's do it properly through option 2.
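One way the shape information could survive once the data is a plain torch.Tensor is to keep it in the annotation only and validate it at function boundaries. The sketch below uses typing.Annotated from the standard library (Python 3.9+). Shape and check_shapes are hypothetical names made up for illustration; this is not how TorchTyping or DocArray implement it.

```python
import functools
import inspect
from typing import Annotated, get_type_hints

import torch


class Shape:
    """Hypothetical annotation marker carrying the expected tensor shape."""

    def __init__(self, *dims):
        self.dims = tuple(dims)


def check_shapes(fn):
    """Hypothetical decorator: validate Shape-annotated arguments at call time."""
    hints = get_type_hints(fn, include_extras=True)
    expected = {}
    for name, hint in hints.items():
        # Annotated[...] hints expose their extra arguments via __metadata__.
        for meta in getattr(hint, "__metadata__", ()):
            if isinstance(meta, Shape):
                expected[name] = meta.dims
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, dims in expected.items():
            if name in bound.arguments:
                actual = tuple(bound.arguments[name].shape)
                if actual != dims:
                    raise TypeError(f"{name}: expected shape {dims}, got {actual}")
        return fn(*args, **kwargs)

    return wrapper


@check_shapes
def my_helper(t: Annotated[torch.Tensor, Shape(512)]) -> torch.Tensor:
    return t * 2


t = torch.rand(512) + torch.rand(512)  # a plain torch.Tensor
result = my_helper(t)                  # passes the runtime shape check
```

With this design the value stays a torch.Tensor everywhere, and a call such as my_helper(torch.rand(3)) raises a TypeError at the boundary instead of relying on the tensor's class to carry the shape.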