v2: parametrized tensor shape during operations #994

@JohannesMessner

Description

When operations on a parametrized tensor are performed, the supposed shape of that tensor (given in the type hint) does not change, even when the operation does change the actual shape:

import torch

from docarray import BaseDocument
from docarray.typing import TorchTensor


class Doc(BaseDocument):
    tensor: TorchTensor[3, 1]


d = Doc(tensor=torch.rand(size=(3, 1)))
d_t = d.tensor.transpose(0, 1)
print(d_t)        # TorchTensor[3, 1]([[0.7769, 0.8053, 0.0161]])
print(d_t.shape)  # torch.Size([1, 3])
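The staleness can be reproduced without docarray or torch: the parametrized shape lives on the class, not on the data, so any operation that returns the same class keeps the old shape label. A minimal pure-Python sketch (all names hypothetical, not docarray's implementation):

```python
class ShapedTensor(list):
    """Hypothetical stand-in for a shape-parametrized tensor type."""

    _shape = None  # shape promised by the type hint, stored on the class

    def __class_getitem__(cls, item):
        shape = item if isinstance(item, tuple) else (item,)
        # Each parametrization creates a new subclass that remembers its shape.
        return type(f"ShapedTensor{list(shape)}", (cls,), {"_shape": shape})

    def __repr__(self):
        return f"{type(self).__name__}({list(self)})"


def transpose(t):
    # Returns the same subclass without updating the class-level shape.
    return type(t)(list(map(list, zip(*t))))


T31 = ShapedTensor[3, 1]
t = T31([[0.1], [0.2], [0.3]])
t_t = transpose(t)
print(t_t)  # ShapedTensor[3, 1]([[0.1, 0.2, 0.3]]) -- label says 3x1, data is 1x3
```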

I see two options to tackle this:

  1. Override __torch_function__ to assign the correct shape to the tensor's parametrized type after every operation. In the example above, the first print would then produce TorchTensor[1, 3]([[0.7769, 0.8053, 0.0161]]). The problem: this re-parametrization has to happen on every single torch operation, which is error-prone.
  2. Override __torch_function__ to return a torch.Tensor instead of a TorchTensor. The challenge: as things stand, this would make our type system useless outside of a Document. For example, the following would not type check:
import torch

from docarray import BaseDocument
from docarray.typing import TorchTensor


def my_helper(t: TorchTensor[512]):
    ...


class Doc(BaseDocument):
    tensor: TorchTensor[512]


d = Doc(tensor=torch.rand(512))
t = d.tensor + d.tensor  # t is now a plain torch.Tensor
my_helper(t)  # but my_helper expects TorchTensor[512]
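For reference, the mechanics of option 2 can be sketched as a torch.Tensor subclass whose __torch_function__ downgrades every result to a plain torch.Tensor. The class name below is hypothetical, not docarray's implementation:

```python
import torch


class UnwrapTensor(torch.Tensor):
    """Hypothetical sketch of option 2: any torch operation on this
    subclass returns a plain torch.Tensor instead of the subclass."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Run the op through torch's default dispatch...
        ret = super().__torch_function__(func, types, args, kwargs or {})
        # ...then strip the subclass from any tensor result.
        if isinstance(ret, torch.Tensor):
            return ret.as_subclass(torch.Tensor)
        return ret


t = torch.rand(3, 1).as_subclass(UnwrapTensor)
r = t.transpose(0, 1)
print(type(r))  # <class 'torch.Tensor'>
```

Because the unwrap happens at a single choke point (the __torch_function__ protocol) instead of per-operation, there is no shape bookkeeping to get wrong.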

To make this work, we should take a look at TorchTyping and see how it achieves proper typing even though the underlying data is a plain torch.Tensor.
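Independent of TorchTyping's exact approach, the standard library already allows attaching shape metadata to a plain torch.Tensor hint via typing.Annotated; a minimal sketch (the alias name and metadata format are illustrative only):

```python
from typing import Annotated

import torch

# Static type checkers treat Annotated[torch.Tensor, ...] as torch.Tensor,
# so the result of any torch op still satisfies the hint; the metadata stays
# available for runtime checkers that want to validate the shape.
Tensor512 = Annotated[torch.Tensor, "shape=(512,)"]


def my_helper(t: Tensor512) -> torch.Tensor:
    return t * 2


d = torch.rand(512)
out = my_helper(d + d)  # d + d is a plain torch.Tensor and still type checks
```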
The big advantage of this: as soon as any operation occurs on a TorchTensor, it turns into a torch.Tensor, so there is no opportunity for us to mess things up inside a model.

Preliminary conclusion: Let's do it properly through option 2.

Metadata

Labels: DocArray v2 (this issue is part of the rewrite; not to be merged into main)

Status: Done