imwithye committed on
Commit
5529a00
·
1 Parent(s): 4bdbfb1

add python api

Browse files
Files changed (4) hide show
  1. rlcube/cube2.ipynb +47 -4
  2. rlcube/main.py +23 -0
  3. rlcube/pyproject.toml +2 -1
  4. rlcube/uv.lock +0 -0
rlcube/cube2.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 55,
6
  "id": "dff864f2",
7
  "metadata": {},
8
  "outputs": [],
@@ -240,7 +240,7 @@
240
  },
241
  {
242
  "cell_type": "code",
243
- "execution_count": 56,
244
  "id": "624c83c1",
245
  "metadata": {},
246
  "outputs": [],
@@ -269,6 +269,49 @@
269
  " return obs, reward, terminated, truncated, _"
270
  ]
271
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  {
273
  "cell_type": "code",
274
  "execution_count": null,
@@ -516,7 +559,7 @@
516
  ],
517
  "metadata": {
518
  "kernelspec": {
519
- "display_name": "dev",
520
  "language": "python",
521
  "name": "python3"
522
  },
@@ -530,7 +573,7 @@
530
  "name": "python",
531
  "nbconvert_exporter": "python",
532
  "pygments_lexer": "ipython3",
533
- "version": "3.13.5"
534
  }
535
  },
536
  "nbformat": 4,
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "dff864f2",
7
  "metadata": {},
8
  "outputs": [],
 
240
  },
241
  {
242
  "cell_type": "code",
243
+ "execution_count": 2,
244
  "id": "624c83c1",
245
  "metadata": {},
246
  "outputs": [],
 
269
  " return obs, reward, terminated, truncated, _"
270
  ]
271
  },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "id": "7a81c85a",
276
+ "metadata": {},
277
+ "outputs": [
278
+ {
279
+ "name": "stdout",
280
+ "output_type": "stream",
281
+ "text": [
282
+ "[[0. 0. 1. 0. 0. 0.]\n",
283
+ " [1. 0. 0. 0. 0. 0.]\n",
284
+ " [0. 0. 0. 0. 0. 1.]\n",
285
+ " [0. 0. 0. 0. 0. 1.]\n",
286
+ " [0. 0. 0. 1. 0. 0.]\n",
287
+ " [0. 1. 0. 0. 0. 0.]\n",
288
+ " [0. 0. 0. 0. 1. 0.]\n",
289
+ " [0. 0. 0. 0. 1. 0.]\n",
290
+ " [0. 0. 1. 0. 0. 0.]\n",
291
+ " [0. 1. 0. 0. 0. 0.]\n",
292
+ " [0. 0. 1. 0. 0. 0.]\n",
293
+ " [0. 1. 0. 0. 0. 0.]\n",
294
+ " [1. 0. 0. 0. 0. 0.]\n",
295
+ " [1. 0. 0. 0. 0. 0.]\n",
296
+ " [0. 0. 0. 1. 0. 0.]\n",
297
+ " [0. 0. 0. 1. 0. 0.]\n",
298
+ " [0. 0. 0. 0. 1. 0.]\n",
299
+ " [0. 0. 0. 0. 1. 0.]\n",
300
+ " [1. 0. 0. 0. 0. 0.]\n",
301
+ " [0. 0. 1. 0. 0. 0.]\n",
302
+ " [0. 0. 0. 0. 0. 1.]\n",
303
+ " [0. 0. 0. 0. 0. 1.]\n",
304
+ " [0. 1. 0. 0. 0. 0.]\n",
305
+ " [0. 0. 0. 1. 0. 0.]]\n"
306
+ ]
307
+ }
308
+ ],
309
+ "source": [
310
+ "env = RewardWrapper(Cube2())\n",
311
+ "obs, _ = env.reset()\n",
312
+ "print(env.state())"
313
+ ]
314
+ },
315
  {
316
  "cell_type": "code",
317
  "execution_count": null,
 
559
  ],
560
  "metadata": {
561
  "kernelspec": {
562
+ "display_name": ".venv",
563
  "language": "python",
564
  "name": "python3"
565
  },
 
573
  "name": "python",
574
  "nbconvert_exporter": "python",
575
  "pygments_lexer": "ipython3",
576
+ "version": "3.12.11"
577
  }
578
  },
579
  "nbformat": 4,
rlcube/main.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from fastapi import FastAPI
3
+ from pydantic import BaseModel
4
+ from fastapi import HTTPException
5
+
6
+ app = FastAPI()
7
+
8
+
9
+ class StateArgs(BaseModel):
10
+ state: List[List[int]]
11
+
12
+
13
+ @app.post("/solve")
14
+ def solve(body: StateArgs):
15
+ state = body.state
16
+ if not (
17
+ isinstance(state, list)
18
+ and len(state) == 6
19
+ and all(isinstance(row, list) and len(row) == 4 for row in state)
20
+ ):
21
+ raise HTTPException(status_code=400, detail="state must be a 6x4 matrix")
22
+
23
+ return {"steps": [1, 2, 1, 1]}
rlcube/pyproject.toml CHANGED
@@ -5,7 +5,8 @@ description = "Reinforcement Learning for Rubik's Cube"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
 
8
  "gymnasium>=1.2.0",
 
9
  "numpy>=2.3.2",
10
- "stable-baselines3>=2.7.0",
11
  ]
 
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
+ "fastapi[standard]>=0.116.2",
9
  "gymnasium>=1.2.0",
10
+ "ipykernel>=6.30.1",
11
  "numpy>=2.3.2",
 
12
  ]
rlcube/uv.lock CHANGED
The diff for this file is too large to render. See raw diff