import hashlib
import json
from typing import List, Union
from pydantic import BaseModel
[docs]
class RLTuple(BaseModel):
    """
    Pydantic model holding the fields of a single RL tuple.
    """

    # Either a list or None.
    observation: Union[list, None]
    action: list
    reward: float
    # Two independent boolean flags stored alongside the step data.
    terminal: bool
    timeout: bool
[docs]
class TupleItem(BaseModel):
    """
    An ``RLTuple`` paired with a run id.

    ``TupleDB.record_tuples`` builds these by stamping each tuple with the
    database's current run id.
    """

    run_id: str
    tuple: RLTuple
[docs]
class RecordedRun(BaseModel):
    """
    Description of one run: its id, the scenario id and the config dict.

    ``TupleDB.create_run`` builds an instance and hands it to the storage
    backend.
    """

    run_id: str
    scenario_id: str
    config: dict
[docs]
class TupleDB:
    """
    Base class for tuple database.

    This class derives run ids and stamps recorded tuples with the id of the
    current run. Concrete backends override ``list_runs``, ``get_tuples``,
    ``_create_run`` and ``_record_tuples``; the versions here raise
    ``NotImplementedError``.
    """

    # Class-level default so the attribute exists even on subclasses whose
    # __init__ does not call super().__init__().
    current_run_id: Union[str, None] = None

    def __init__(self) -> None:
        self.current_run_id = None

    def create_run(self, scenario_id: str, config: dict, seed: int):
        """
        Start a new run and make it the current one.

        The run id is the SHA-256 hex digest of the key-sorted JSON dump of
        ``config`` followed by ``scenario_id`` and ``str(seed)``, so the same
        three inputs always produce the same id.

        Raises:
            TypeError: if ``config`` cannot be serialized by ``json.dumps``.
        """
        # NOTE: the parts are joined without a separator, so e.g.
        # ("a1", seed=2) and ("a", seed=12) hash identically. Kept unchanged
        # so ids stay compatible with runs that are already stored.
        digest_input = json.dumps(config, sort_keys=True) + scenario_id + str(seed)
        self.current_run_id = hashlib.sha256(digest_input.encode()).hexdigest()
        self._create_run(
            run=RecordedRun(
                run_id=self.current_run_id,
                scenario_id=scenario_id,
                config=config,
            )
        )

    def record_tuples(self, tuples: List[RLTuple]):
        """
        Store ``tuples`` under the current run.

        Raises:
            RuntimeError: if no run has been created yet.
        """
        if self.current_run_id is None:
            # Fail with a clear message instead of an AttributeError or a
            # validation error on run_id=None further down.
            raise RuntimeError(
                "No active run: call create_run() before record_tuples()."
            )
        self._record_tuples(
            tuples=[TupleItem(run_id=self.current_run_id, tuple=t) for t in tuples]
        )

    def list_runs(self) -> List[RecordedRun]:
        """Return the recorded runs; must be provided by the backend."""
        raise NotImplementedError()

    def get_tuples(self, filters: Union[dict, None] = None) -> List[TupleItem]:
        """
        Return the stored tuples; must be provided by the backend.

        ``filters`` defaults to ``None`` rather than a shared mutable ``{}``.
        """
        raise NotImplementedError()

    def _create_run(self, run: RecordedRun):
        """Backend hook: persist ``run``."""
        raise NotImplementedError()

    def _record_tuples(self, tuples: List[TupleItem]):
        """Backend hook: persist ``tuples``."""
        raise NotImplementedError()
[docs]
class MongoTupleDB(TupleDB):
    """
    Tuple database stored in MongoDB.

    Runs go to the ``runs`` collection and tuples to the ``tuples``
    collection of the database named ``db_name``.
    """

    def __init__(self, db_url: str = "mongodb://localhost:27017/", db_name: str = "tuple_db"):
        # Function-scope import; presumably so pymongo is only needed when
        # this backend is actually used.
        from pymongo import MongoClient

        # Initialise base-class state (current_run_id).
        super().__init__()
        self.db = MongoClient(db_url)[db_name]

    def list_runs(self) -> List[RecordedRun]:
        """Return every document of the ``runs`` collection as a RecordedRun."""
        return [RecordedRun(**document) for document in self.db.runs.find()]

    def get_tuples(self, filters: Union[dict, None] = None) -> List[TupleItem]:
        """
        Return the tuples matching ``filters``.

        ``filters`` is passed to ``find`` on the ``tuples`` collection;
        ``None`` (the default) selects everything.
        """
        query = {} if filters is None else filters
        return [TupleItem(**document) for document in self.db.tuples.find(query)]

    def _create_run(self, run: RecordedRun):
        """Insert ``run`` into the ``runs`` collection."""
        self.db.runs.insert_one(run.model_dump())

    def _record_tuples(self, tuples: List[TupleItem]):
        """Insert ``tuples`` into the ``tuples`` collection."""
        documents = [item.model_dump() for item in tuples]
        if not documents:
            # insert_many rejects an empty list, so an empty batch is a no-op.
            return
        self.db.tuples.insert_many(documents)
[docs]
class LocalFileTupleDB(TupleDB):
    """
    Here we store everything in local files.
    They are written and read line by line.

    Runs go to ``<base_db_name>_runs.txt`` and tuples to
    ``<base_db_name>_tuples.txt``, one JSON document per line.
    """

    def __init__(self, base_db_name: str = "."):
        # Initialise base-class state (current_run_id).
        super().__init__()
        self.db = base_db_name
        self.runs_file = base_db_name + "_runs.txt"
        self.tuples_file = base_db_name + "_tuples.txt"

    def list_runs(self) -> List[RecordedRun]:
        """Return every run stored in the runs file (empty if none yet)."""
        return [RecordedRun(**record) for record in self._read_records(self.runs_file)]

    def get_tuples(self, filters: Union[dict, None] = None) -> List[TupleItem]:
        """
        Return the stored tuples, optionally restricted by ``filters``.

        Only equality filters are supported: each key names a field (a dotted
        path such as ``"tuple.terminal"`` reaches into nested objects) and the
        stored value must equal the given one. ``None`` or ``{}`` selects
        everything.

        Raises:
            NotImplementedError: if ``filters`` uses ``$``-prefixed operators.
        """
        records = self._read_records(self.tuples_file)
        if filters:
            self._reject_operators(filters)
            records = [record for record in records if self._matches(record, filters)]
        return [TupleItem(**record) for record in records]

    def _create_run(self, run: RecordedRun):
        """Append ``run`` as one JSON line to the runs file."""
        with open(self.runs_file, 'a', encoding='utf-8') as file:
            file.write(json.dumps(run.model_dump()) + '\n')

    def _record_tuples(self, tuples: List[TupleItem]):
        """Append each item of ``tuples`` as one JSON line to the tuples file."""
        with open(self.tuples_file, 'a', encoding='utf-8') as file:
            file.writelines(json.dumps(item.model_dump()) + '\n' for item in tuples)

    @staticmethod
    def _read_records(path: str) -> List[dict]:
        """Parse the JSON lines of ``path``; a missing file yields no records."""
        try:
            with open(path, 'r', encoding='utf-8') as file:
                # Skip blank lines so stray newlines do not break parsing.
                return [json.loads(line) for line in file if line.strip()]
        except FileNotFoundError:
            # Nothing has been written yet.
            return []

    @staticmethod
    def _reject_operators(filters: dict) -> None:
        """Raise if ``filters`` contains ``$``-style query operators."""
        for key, expected in filters.items():
            nested_operator = isinstance(expected, dict) and any(
                str(inner).startswith("$") for inner in expected
            )
            if str(key).startswith("$") or nested_operator:
                raise NotImplementedError(
                    "LocalFileTupleDB only supports equality filters"
                )

    @staticmethod
    def _matches(record: dict, filters: dict) -> bool:
        """Tell whether ``record`` equals ``filters`` on every filtered field."""
        for key, expected in filters.items():
            value = record
            for part in str(key).split("."):
                if not isinstance(value, dict) or part not in value:
                    return False
                value = value[part]
            if value != expected:
                return False
        return True