imwithye committed on
Commit
6d11af2
·
1 Parent(s): 90d394b

add value network

Browse files
Files changed (2) hide show
  1. rlcube/cube2.ipynb +3 -3
  2. rlcube/models/models.py +42 -14
rlcube/cube2.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 3,
6
  "id": "624c83c1",
7
  "metadata": {},
8
  "outputs": [],
@@ -37,7 +37,7 @@
37
  },
38
  {
39
  "cell_type": "code",
40
- "execution_count": 4,
41
  "id": "7a81c85a",
42
  "metadata": {},
43
  "outputs": [
@@ -57,7 +57,7 @@
57
  "source": [
58
  "env = RewardWrapper(Cube2())\n",
59
  "obs, _ = env.reset()\n",
60
- "print(env.state())"
61
  ]
62
  }
63
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "624c83c1",
7
  "metadata": {},
8
  "outputs": [],
 
37
  },
38
  {
39
  "cell_type": "code",
40
+ "execution_count": null,
41
  "id": "7a81c85a",
42
  "metadata": {},
43
  "outputs": [
 
57
  "source": [
58
  "env = RewardWrapper(Cube2())\n",
59
  "obs, _ = env.reset()\n",
60
+ "print(obs.shape)"
61
  ]
62
  }
63
  ],
rlcube/models/models.py CHANGED
@@ -5,27 +5,55 @@ import torch
5
 
6
 
7
class ResidualBlock(nn.Module):
    """Pre-activation residual MLP block.

    Applies BatchNorm -> ReLU -> Linear twice (input_dim -> hidden_dim ->
    input_dim) and adds the identity shortcut, so input and output share
    the same feature dimension.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # Pre-activation ordering: normalize and activate before each Linear.
        h = F.relu(self.bn1(x))
        h = self.fc1(h)
        h = F.relu(self.bn2(h))
        h = self.fc2(h)
        # Identity skip connection.
        return h + x
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
if __name__ == "__main__":
    # Quick shape sanity check for ResidualBlock on a random batch.
    print("Testing ResidualBlock, input_dim=24, hidden_dim=128")
    sample = torch.randn(4, 24)
    print("Input shape:", sample.shape)
    block = ResidualBlock(24, 128)
    print("Output shape:", block(sample).shape)
 
 
5
 
6
 
7
class ResidualBlock(nn.Module):
    """Pre-norm residual MLP block.

    Applies LayerNorm -> ReLU -> Linear twice (dim -> hidden_dim -> dim)
    and adds the identity shortcut; LayerNorm acts on the last dimension,
    so any leading batch shape is supported.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        # Pre-norm ordering: normalize and activate before each Linear.
        h = F.relu(self.ln1(x))
        h = self.fc1(h)
        h = F.relu(self.ln2(h))
        h = self.fc2(h)
        # Identity skip connection.
        return h + x
21
 
22
 
23
class Cube2VNetwork(L.LightningModule):
    """Value network for the 2x2x2 cube.

    Maps a one-hot encoded cube state of shape (batch, 24, 6) — 24 sticker
    positions x 6 colors, per the __main__ smoke test — to a scalar value
    estimate of shape (batch, 1).

    Args:
        hidden_dim: width of the trunk (default 512).
        num_residual_blocks: number of ResidualBlock layers (default 4).
    """

    def __init__(self, hidden_dim=512, num_residual_blocks=4):
        super().__init__()
        # 24 sticker positions x 6 one-hot colors, flattened to one vector.
        input_dim = 24 * 6
        self.fc_in = nn.Linear(input_dim, hidden_dim)

        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(hidden_dim, hidden_dim * 2)
                for _ in range(num_residual_blocks)
            ]
        )

        # Single scalar value head.
        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # flatten(1) instead of view(batch, -1): view() raises a
        # RuntimeError on non-contiguous inputs (e.g. the result of a
        # permute/transpose), while flatten copies only when it has to.
        out = F.relu(self.fc_in(x.flatten(1)))
        for block in self.residual_blocks:
            out = block(out)
        return self.fc_out(out)
46
+
47
+
48
if __name__ == "__main__":
    # Smoke-test ResidualBlock on a batch of one-hot-like cube encodings.
    print("Testing ResidualBlock, input_dim=24, hidden_dim=128")
    sample = torch.randn(4, 24, 6)
    print("Input shape:", sample.shape)
    block = ResidualBlock(6, 128)
    print("Output shape:", block(sample).shape)

    # Smoke-test the full value network end to end.
    print("Testing Cube2VNetwork, input_dim=24, num_residual_blocks=4")
    sample = torch.randn(4, 24, 6)
    net = Cube2VNetwork()
    print("Input shape:", sample.shape)
    print("Output shape:", net(sample).shape)
    param_count = sum(p.numel() for p in net.parameters())
    print("Number of parameters:", param_count)