[pytorch] view() vs. reshape() ๋น๊ต
๐ฃ Intro
- pytorch๋ฅผ ์ฌ์ฉํ๋ค๋ณด๋ฉด tensor์ shape์ ๋ฐ๊ฟ์ผํ๋ ์์
์ด ๋งค์ฐ ๋น๋ฒํ๋ค. ์๋๋ฉด ๋ชจ๋ธ์ forward๋ฅผ ํ๋๋ผ๋ 2์ฐจ์ ๋ฐ์ดํฐ์ ๊ฒฝ์ฐ๋ (B,C,W,H)์ฒ๋ผ ์ฐจ์์ ์ค์ ํด์ค์ผ ๊ฐ ๋ ์ด์ด์์ ํ์ต์ด ๊ฐ๋ฅํ๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ๋ด๋ ๋ง์ด ์ฐ๋ torch์ ๋ฉ์๋๊ฐ
reshape()
๊ณผview()
์ด๋ค.
####
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
a = torch.randn(3, 2, 4)
b = a.reshape(-1) # reshape() ๋ฉ์๋๋ฅผ ์จ์ 1์ฐจ์์ผ๋ก ๋ง๋ฆ
b[0] = 99999
a # a์๋ ์ํฅ์ ์ค๋ค
a = torch.randn(3, 2, 4)
b = a.view(-1) # reshape() ๋ฉ์๋๋ฅผ ์จ์ 1์ฐจ์์ผ๋ก ๋ง๋ฆ
b[0] = 99999
a # a์๋ ์ํฅ์ ์ค๋ค
#### ๐ก ๊น๊นํ ์์น์ฃผ์์ (`view()`) vs. ์ตํต์ฑ ์๋ ํด๊ฒฐ์ฌ (`reshape()`)
โ No.
```py
import torch
# 1. ์ผ๋ฐ์ ์ธ ์ํฉ: ๋ ๋ค ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ณต์ ํ๋ค.
x = torch.arange(1, 7) # tensor([1, 2, 3, 4, 5, 6])
print(f"x๋ ์ฐ์์ ์ธ๊ฐ? {x.is_contiguous()}")
# x_view์ x_reshape ๋ชจ๋ x์ ๊ฐ์ ๋ฉ๋ชจ๋ฆฌ ๊ณต๊ฐ์ ๊ฐ๋ฆฌํจ๋ค.
x_view = x.view(2, 3)
x_reshape = x.reshape(2, 3)
# x_view์ ๊ฐ์ ๋ฐ๊พธ๋ฉด...
x_view[0, 0] = 99
# ์๋ณธ x์ x_reshape์ ๊ฐ๋ ๋ชจ๋ ๋ฐ๋๋ค!
print(f"์๋ณธ x: \n{x}")
print(f"x_reshape: \n{x_reshape}")
# ์๋ณธ x:
# tensor([99, 2, 3, 4, 5, 6])
# x_reshape:
# tensor([[99, 2, 3],
# [ 4, 5, 6]])
print("-" * 20)
# 2. ๋ฌธ์ ๊ฐ ๋๋ ์ํฉ: ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ฐ์์ ์ด์ง ์์ ๋
y = torch.arange(1, 7).reshape(2, 3)
y_transposed = y.T # .T๋ .transpose(0, 1)์ ๊ฐ์. ๋ชจ์์ (3,2)๊ฐ ๋จ
print(f"y_transposed๋ ์ฐ์์ ์ธ๊ฐ? {y_transposed.is_contiguous()}")
# y_transposed๋ ์ฐ์์ ์ธ๊ฐ? False <-- ๋ฐ๋ก ์ด ๋!
# 2-1. ์์น์ฃผ์์ view()๋ ์๋ฌ๋ฅผ ๋ฐ์์ํจ๋ค!
try:
y_view = y_transposed.view(2, 3)
except RuntimeError as e:
print(f"view() ์๋ฌ ๋ฐ์: {e}")
# view() ์๋ฌ ๋ฐ์: view size is not compatible with input tensor's size and stride...
# '...Use .reshape() or .contiguous().' ๋ผ๊ณ ์น์ ํ๊ฒ ์๋ ค์ฃผ๋ค!
# 2-2. ํด๊ฒฐ์ฌ reshape()๋ ์กฐ์ฉํ ๋ณต์ฌ๋ณธ์ ๋ง๋ค์ด ์ฒ๋ฆฌํ๋ค.
y_reshape = y_transposed.reshape(2, 3)
# y_reshape์ ๊ฐ์ ๋ฐ๊ฟ๋...
y_reshape[0, 0] = 777
# ์๋ณธ y_transposed์๋ ์๋ฌด๋ฐ ์ํฅ์ด ์๋ค! (๋ณต์ฌ๋ณธ์ด๋๊น)
print(f"\ny_reshape: \n{y_reshape}")
print(f"์๋ณธ y_transposed: \n{y_transposed}")
# y_reshape:
# tensor([[777, 2, 4],
# [ 3, 5, 6]])
# ์๋ณธ y_transposed:
# tensor([[1, 4],
# [2, 5],
# [3, 6]])
#### ๐ก ๊ทธ๋์ ๋ญ ์จ์ผ ํ ๊น? ์ผ๋ฐ์ ์ผ๋ก๋ ๊ทธ๋ฅ reshape()
์ ์ฐ๋ ๊ฒ ํธํ๊ณ ์์ . ๋๋ถ๋ถ์ ๊ฒฝ์ฐ ์ฐ๋ฆฌ๊ฐ ์ํ๋ ๋๋ก ๋์ํ๊ณ , ๊ตณ์ด ๋ฉ๋ชจ๋ฆฌ ์ฐ์์ฑ๊น์ง ์ ๊ฒฝ ์ฐ์ง ์์๋ ๋๋๊น.
view()
๋ ์ด๋ด ๋ ์ฌ์ฉ. ๋ด ์ฝ๋์์ โ์ฌ๊ธฐ๋ ๋ฐ๋์ ๋ฉ๋ชจ๋ฆฌ ๊ณต์ ๊ฐ ์ผ์ด๋์ผ ํ๊ณ , ๋ถํ์ํ ๋ณต์ฌ๊ฐ ์๊ธฐ๋ฉด ์ ๋๋ ์์ฃผ ์ค์ํ ๋ถ๋ถ์ด์ผ!โ๋ผ๊ณ ๋ช
์ํ๊ณ ์ถ์ ๋. ๋๋ view๊ฐ ๋ ๊ฑฐ๋ผ๊ณ 100% ํ์ ํ ๋ ์ฌ์ฉํ๋ฉด, ์๋์น ์์ ๋ณต์ฌ๋ฅผ ๋ง๋ ์์ ์ฅ์น๊ฐ ๋ ์ ์์.
This post is licensed under CC BY 4.0 by the author.