Skip to content

Commit e65bab5

Browse files
committed
pytorch: add squeeze and tensor creation examples
1 parent 73b3e23 commit e65bab5

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch
2+
3+
def main():
4+
print("Squeeze Tensor Example")
5+
t = torch.tensor(
6+
[ # first dimension just has one element
7+
[1, 2, 3] # second dimension has three elements
8+
] # this would be shape [3, 1] in ggml.
9+
)
10+
# So what sequeeze does it it removes the empty dimensions.
11+
print(t.shape)
12+
t = t.squeeze()
13+
print(t.shape)
14+
15+
# And we can also add a dimension back in.
16+
t = t.unsqueeze(0) # Add a dimension at index 0
17+
print(t.shape)
18+
19+
t = t.unsqueeze(0) # Add a dimension at index 0
20+
print(t.shape)
21+
22+
# When a script is run directly the special variable __main__ is set to "__main__"
23+
print(f"__name__ = {__name__}")
24+
if __name__ == "__main__":
25+
main()

fundamentals/pytorch/src/tensor.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
3+
def print_tensor(t: torch.Tensor):
4+
print(t, type(t), t.dtype, t.device, t.requires_grad)
5+
6+
def main():
7+
print("Tensor creation examples")
8+
# We can create a tensor from a Python list:
9+
t = torch.tensor([1, 2])
10+
print_tensor(t)
11+
12+
# From a list of lists:
13+
t = torch.tensor([[1, 2], [3, 4]])
14+
print_tensor(t)
15+
16+
t = torch.tensor(5.0)
17+
print_tensor(t)
18+
19+
t = torch.tensor([0, 0, 0], # Initial data
20+
dtype=None, # Data type (torch.float32, torch.int64, etc.)
21+
device=None, # Device ('cpu', 'cuda', 'cuda:0', etc.)
22+
requires_grad=False) # Track gradients for autograd
23+
print_tensor(t)
24+
25+
t = torch.range(0, 10, 2) # Start, end, how big each step should be (we know the step size here)
26+
print_tensor(t)
27+
28+
t = torch.linspace(0, 10, 5) # Start, end, number of steps (we know the number of points needed)
29+
print_tensor(t)
30+
31+
32+
if __name__ == "__main__":
33+
main()
34+

0 commit comments

Comments
 (0)