Spaces:
Running
Running
add value network
Browse files
- rlcube/cube2.ipynb +3 -3
- rlcube/models/models.py +42 -14
rlcube/cube2.ipynb
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "624c83c1",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
|
@@ -37,7 +37,7 @@
|
|
| 37 |
},
|
| 38 |
{
|
| 39 |
"cell_type": "code",
|
| 40 |
-
"execution_count":
|
| 41 |
"id": "7a81c85a",
|
| 42 |
"metadata": {},
|
| 43 |
"outputs": [
|
|
@@ -57,7 +57,7 @@
|
|
| 57 |
"source": [
|
| 58 |
"env = RewardWrapper(Cube2())\n",
|
| 59 |
"obs, _ = env.reset()\n",
|
| 60 |
-
"print(
|
| 61 |
]
|
| 62 |
}
|
| 63 |
],
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
"id": "624c83c1",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
|
|
|
| 37 |
},
|
| 38 |
{
|
| 39 |
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
"id": "7a81c85a",
|
| 42 |
"metadata": {},
|
| 43 |
"outputs": [
|
|
|
|
| 57 |
"source": [
|
| 58 |
"env = RewardWrapper(Cube2())\n",
|
| 59 |
"obs, _ = env.reset()\n",
|
| 60 |
+
"print(obs.shape)"
|
| 61 |
]
|
| 62 |
}
|
| 63 |
],
|
rlcube/models/models.py
CHANGED
|
@@ -5,27 +5,55 @@ import torch
|
|
| 5 |
|
| 6 |
|
| 7 |
class ResidualBlock(nn.Module):
|
| 8 |
-
def __init__(self,
|
| 9 |
-
super(
|
| 10 |
-
self.
|
| 11 |
-
self.fc1 = nn.Linear(
|
| 12 |
-
self.
|
| 13 |
-
self.fc2 = nn.Linear(hidden_dim,
|
| 14 |
|
| 15 |
def forward(self, x):
|
| 16 |
residual = x
|
| 17 |
-
out = self.
|
| 18 |
-
out = F.relu(out)
|
| 19 |
-
out = self.fc1(out)
|
| 20 |
-
out = self.bn2(out)
|
| 21 |
-
out = F.relu(out)
|
| 22 |
-
out = self.fc2(out)
|
| 23 |
out = out + residual
|
| 24 |
return out
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
if __name__ == "__main__":
|
| 28 |
print("Testing ResidualBlock, input_dim=24, hidden_dim=128")
|
| 29 |
-
x = torch.randn(4, 24)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
print("Input shape:", x.shape)
|
| 31 |
-
print("Output shape:",
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class ResidualBlock(nn.Module):
    """Pre-normalised two-layer MLP block with an identity skip connection.

    The block computes ``x + fc2(relu(ln2(fc1(relu(ln1(x))))))``. Every
    layer acts on the last axis only, so the output has the same shape as
    the input.

    Args:
        dim: Size of the last axis of the input (and of the output).
        hidden_dim: Width of the intermediate projection.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Registration order and attribute names determine the state_dict
        # layout and the order in which parameters are initialised.
        self.ln1 = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        """Return ``x`` plus the transformed branch; shape is unchanged."""
        # First half: normalise -> activate -> expand to hidden_dim.
        hidden = self.ln1(x)
        hidden = F.relu(hidden)
        hidden = self.fc1(hidden)
        # Second half: normalise -> activate -> project back to dim.
        hidden = self.ln2(hidden)
        hidden = F.relu(hidden)
        hidden = self.fc2(hidden)
        # Identity skip connection.
        return hidden + x
|
| 21 |
|
| 22 |
|
| 23 |
+
class Cube2VNetwork(L.LightningModule):
    """Value network: maps a batch of observations to one scalar each.

    Architecture: flatten -> ``fc_in`` -> ReLU -> ``num_residual_blocks``
    ``ResidualBlock`` layers -> ``fc_out``.

    Args:
        hidden_dim: Width of the residual trunk. Each residual block
            expands to ``hidden_dim * 2`` internally.
        num_residual_blocks: Number of residual blocks in the trunk.
        input_dim: Number of features per sample after flattening.
            Defaults to ``24 * 6``, matching the ``(batch, 24, 6)`` input
            used by this module's self-test.
            NOTE(review): presumably 24 stickers x 6 one-hot colours;
            confirm against the environment's observation space.
    """

    def __init__(self, hidden_dim=512, num_residual_blocks=4, input_dim=24 * 6):
        super().__init__()
        self.fc_in = nn.Linear(input_dim, hidden_dim)

        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(hidden_dim, hidden_dim * 2)
                for _ in range(num_residual_blocks)
            ]
        )

        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the value estimate, shape ``(batch, 1)``.

        Args:
            x: Tensor whose first axis is the batch; all remaining axes
                are flattened and must total ``input_dim`` elements.
        """
        batch_size = x.size(0)
        # reshape (not view): view raises on non-contiguous tensors such as
        # permuted or sliced batches, while reshape copies only when it must.
        x = x.reshape(batch_size, -1)
        out = F.relu(self.fc_in(x))
        for block in self.residual_blocks:
            out = block(out)
        out = self.fc_out(out)
        return out
|
| 46 |
+
|
| 47 |
+
|
| 48 |
if __name__ == "__main__":
    # Smoke test: build each module, push a random batch through it and
    # print the shapes. The banners are derived from the values actually
    # used so they cannot drift from the code.
    dim, hidden_dim = 6, 128
    print(f"Testing ResidualBlock, dim={dim}, hidden_dim={hidden_dim}")
    x = torch.randn(4, 24, dim)
    print("Input shape:", x.shape)
    print("Output shape:", ResidualBlock(dim, hidden_dim)(x).shape)

    num_residual_blocks = 4
    x = torch.randn(4, 24, 6)
    # Features per sample after flattening (24 * 6 = 144).
    print(
        f"Testing Cube2VNetwork, input_dim={x[0].numel()}, "
        f"num_residual_blocks={num_residual_blocks}"
    )
    net = Cube2VNetwork(num_residual_blocks=num_residual_blocks)
    print("Input shape:", x.shape)
    print("Output shape:", net(x).shape)
    print("Number of parameters:", sum(p.numel() for p in net.parameters()))
|