imwithye commited on
Commit
11d757e
·
1 Parent(s): 6d11af2
rlcube/{envs → rlcube/envs}/cube2.py RENAMED
File without changes
rlcube/{models → rlcube/models}/models.py RENAMED
@@ -1,4 +1,3 @@
1
- import lightning as L
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import torch
@@ -20,7 +19,7 @@ class ResidualBlock(nn.Module):
20
  return out
21
 
22
 
23
- class Cube2VNetwork(L.LightningModule):
24
  def __init__(self, hidden_dim=512, num_residual_blocks=4):
25
  super().__init__()
26
  input_dim = 24 * 6
@@ -53,7 +52,7 @@ if __name__ == "__main__":
53
 
54
  print("Testing Cube2VNetwork, input_dim=24, num_residual_blocks=4")
55
  x = torch.randn(4, 24, 6)
56
- net = Cube2VNetwork()
57
  print("Input shape:", x.shape)
58
  print("Output shape:", net(x).shape)
59
  print("Number of parameters:", sum(p.numel() for p in net.parameters()))
 
 
1
  import torch.nn as nn
2
  import torch.nn.functional as F
3
  import torch
 
19
  return out
20
 
21
 
22
+ class DNN(nn.Module):
23
  def __init__(self, hidden_dim=512, num_residual_blocks=4):
24
  super().__init__()
25
  input_dim = 24 * 6
 
52
 
53
  print("Testing Cube2VNetwork, input_dim=24, num_residual_blocks=4")
54
  x = torch.randn(4, 24, 6)
55
+ net = DNN()
56
  print("Input shape:", x.shape)
57
  print("Output shape:", net(x).shape)
58
  print("Number of parameters:", sum(p.numel() for p in net.parameters()))
rlcube/rlcube/train/train.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
4
+
5
+ from rlcube.envs.cube2 import Cube2
6
+
7
+
8
+ def generate_train_data(num_envs: int = 1000, num_steps: int = 20):
9
+ for _ in range(num_envs):
10
+ env = Cube2()
11
+ obs, _ = env.reset()
12
+
13
+ print(obs)
14
+ break
15
+
16
+
17
+ if __name__ == "__main__":
18
+ generate_train_data()