Post

[pytorch] view() vs. reshape() ๋น„๊ต

๐ŸŸฃ Intro

  • pytorch๋ฅผ ์‚ฌ์šฉํ•˜๋‹ค๋ณด๋ฉด tensor์˜ shape์„ ๋ฐ”๊ฟ”์•ผํ•˜๋Š” ์ž‘์—…์ด ๋งค์šฐ ๋นˆ๋ฒˆํ•˜๋‹ค. ์™œ๋ƒ๋ฉด ๋ชจ๋ธ์— forward๋ฅผ ํ•˜๋”๋ผ๋„ 2์ฐจ์› ๋ฐ์ดํ„ฐ์˜ ๊ฒฝ์šฐ๋Š” (B,C,W,H)์ฒ˜๋Ÿผ ์ฐจ์›์„ ์„ค์ •ํ•ด์ค˜์•ผ ๊ฐ ๋ ˆ์ด์–ด์—์„œ ํ•™์Šต์ด ๊ฐ€๋Šฅํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ๊ทธ๋Ÿด๋•Œ ๋งŽ์ด ์“ฐ๋Š” torch์˜ ๋ฉ”์„œ๋“œ๊ฐ€ reshape()๊ณผ view()์ด๋‹ค.

####

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
a = torch.randn(3, 2, 4)
b = a.reshape(-1)  # reshape() ๋ฉ”์„œ๋“œ๋ฅผ ์จ์„œ 1์ฐจ์›์œผ๋กœ ๋งŒ๋“ฆ
b[0] = 99999

a  # a์—๋„ ์˜ํ–ฅ์„ ์ค€๋‹ค

a = torch.randn(3, 2, 4)
b = a.view(-1)  # reshape() ๋ฉ”์„œ๋“œ๋ฅผ ์จ์„œ 1์ฐจ์›์œผ๋กœ ๋งŒ๋“ฆ
b[0] = 99999

a  # a์—๋„ ์˜ํ–ฅ์„ ์ค€๋‹ค



#### ๐ŸŸก ๊น๊นํ•œ ์›์น™์ฃผ์˜์ž (`view()`) vs. ์œตํ†ต์„ฑ ์žˆ๋Š” ํ•ด๊ฒฐ์‚ฌ (`reshape()`)

โŒ No.
```py
import torch
# 1. ์ผ๋ฐ˜์ ์ธ ์ƒํ™ฉ: ๋‘˜ ๋‹ค ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๊ณต์œ ํ•œ๋‹ค.
x = torch.arange(1, 7) # tensor([1, 2, 3, 4, 5, 6])
print(f"x๋Š” ์—ฐ์†์ ์ธ๊ฐ€? {x.is_contiguous()}")

# x_view์™€ x_reshape ๋ชจ๋‘ x์™€ ๊ฐ™์€ ๋ฉ”๋ชจ๋ฆฌ ๊ณต๊ฐ„์„ ๊ฐ€๋ฆฌํ‚จ๋‹ค.
x_view = x.view(2, 3)
x_reshape = x.reshape(2, 3)

# x_view์˜ ๊ฐ’์„ ๋ฐ”๊พธ๋ฉด...
x_view[0, 0] = 99

# ์›๋ณธ x์™€ x_reshape์˜ ๊ฐ’๋„ ๋ชจ๋‘ ๋ฐ”๋€๋‹ค!
print(f"์›๋ณธ x: \n{x}")
print(f"x_reshape: \n{x_reshape}")
# ์›๋ณธ x:
# tensor([99,  2,  3,  4,  5,  6])
# x_reshape:
# tensor([[99,  2,  3],
#         [ 4,  5,  6]])
print("-" * 20)


# 2. ๋ฌธ์ œ๊ฐ€ ๋˜๋Š” ์ƒํ™ฉ: ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ์—ฐ์†์ ์ด์ง€ ์•Š์„ ๋•Œ
y = torch.arange(1, 7).reshape(2, 3)
y_transposed = y.T # .T๋Š” .transpose(0, 1)์™€ ๊ฐ™์Œ. ๋ชจ์–‘์€ (3,2)๊ฐ€ ๋จ
print(f"y_transposed๋Š” ์—ฐ์†์ ์ธ๊ฐ€? {y_transposed.is_contiguous()}")
# y_transposed๋Š” ์—ฐ์†์ ์ธ๊ฐ€? False  <-- ๋ฐ”๋กœ ์ด ๋•Œ!

# 2-1. ์›์น™์ฃผ์˜์ž view()๋Š” ์—๋Ÿฌ๋ฅผ ๋ฐœ์ƒ์‹œํ‚จ๋‹ค!
try:
    y_view = y_transposed.view(2, 3)
except RuntimeError as e:
    print(f"view() ์—๋Ÿฌ ๋ฐœ์ƒ: {e}")
# view() ์—๋Ÿฌ ๋ฐœ์ƒ: view size is not compatible with input tensor's size and stride...
# '...Use .reshape() or .contiguous().' ๋ผ๊ณ  ์นœ์ ˆํ•˜๊ฒŒ ์•Œ๋ ค์ฃผ๋„ค!

# 2-2. ํ•ด๊ฒฐ์‚ฌ reshape()๋Š” ์กฐ์šฉํžˆ ๋ณต์‚ฌ๋ณธ์„ ๋งŒ๋“ค์–ด ์ฒ˜๋ฆฌํ•œ๋‹ค.
y_reshape = y_transposed.reshape(2, 3)

# y_reshape์˜ ๊ฐ’์„ ๋ฐ”๊ฟ”๋„...
y_reshape[0, 0] = 777

# ์›๋ณธ y_transposed์—๋Š” ์•„๋ฌด๋Ÿฐ ์˜ํ–ฅ์ด ์—†๋‹ค! (๋ณต์‚ฌ๋ณธ์ด๋‹ˆ๊นŒ)
print(f"\ny_reshape: \n{y_reshape}")
print(f"์›๋ณธ y_transposed: \n{y_transposed}")
# y_reshape:
# tensor([[777,   2,   4],
#         [  3,   5,   6]])
# ์›๋ณธ y_transposed:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

#### ๐ŸŸก ๊ทธ๋ž˜์„œ ๋ญ˜ ์จ์•ผ ํ• ๊นŒ? ์ผ๋ฐ˜์ ์œผ๋กœ๋Š” ๊ทธ๋ƒฅ reshape()์„ ์“ฐ๋Š” ๊ฒŒ ํŽธํ•˜๊ณ  ์•ˆ์ „. ๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ ์šฐ๋ฆฌ๊ฐ€ ์›ํ•˜๋Š” ๋Œ€๋กœ ๋™์ž‘ํ•˜๊ณ , ๊ตณ์ด ๋ฉ”๋ชจ๋ฆฌ ์—ฐ์†์„ฑ๊นŒ์ง€ ์‹ ๊ฒฝ ์“ฐ์ง€ ์•Š์•„๋„ ๋˜๋‹ˆ๊นŒ.

view()๋Š” ์ด๋Ÿด ๋•Œ ์‚ฌ์šฉ. ๋‚ด ์ฝ”๋“œ์—์„œ โ€œ์—ฌ๊ธฐ๋Š” ๋ฐ˜๋“œ์‹œ ๋ฉ”๋ชจ๋ฆฌ ๊ณต์œ ๊ฐ€ ์ผ์–ด๋‚˜์•ผ ํ•˜๊ณ , ๋ถˆํ•„์š”ํ•œ ๋ณต์‚ฌ๊ฐ€ ์ƒ๊ธฐ๋ฉด ์•ˆ ๋˜๋Š” ์•„์ฃผ ์ค‘์š”ํ•œ ๋ถ€๋ถ„์ด์•ผ!โ€๋ผ๊ณ  ๋ช…์‹œํ•˜๊ณ  ์‹ถ์„ ๋•Œ. ๋˜๋Š” view๊ฐ€ ๋  ๊ฑฐ๋ผ๊ณ  100% ํ™•์‹ ํ•  ๋•Œ ์‚ฌ์šฉํ•˜๋ฉด, ์˜๋„์น˜ ์•Š์€ ๋ณต์‚ฌ๋ฅผ ๋ง‰๋Š” ์•ˆ์ „์žฅ์น˜๊ฐ€ ๋  ์ˆ˜ ์žˆ์Œ.

This post is licensed under CC BY 4.0 by the author.

ยฉ 2025 Soohyun Jeon โญ

๐ŸŒฑ Mostly to remember, sometimes to understand.