159 lines
3.6 KiB
Python
159 lines
3.6 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
from pathlib import Path
|
|
import json
|
|
import time
|
|
|
|
from core.tools.base import BaseTool, ToolContext
|
|
from core.tools.registry import registry
|
|
from core.events import bus
|
|
from core.config import WORKSPACE_ROOT
|
|
|
|
|
|
class MemoryTool(BaseTool):
|
|
"""
|
|
Persistent memory store for agent experiences.
|
|
|
|
Stores:
|
|
- research results
|
|
- tool outputs
|
|
- agent decisions
|
|
- arbitrary notes
|
|
"""
|
|
|
|
name = "memory"
|
|
description = "Persistent memory storage and retrieval"
|
|
|
|
def __init__(self):
|
|
self.memory_file = Path(WORKSPACE_ROOT) / "memory_store.json"
|
|
self._ensure_file()
|
|
|
|
# =========================
|
|
# EXECUTE
|
|
# =========================
|
|
|
|
def execute(self, payload: dict[str, Any], ctx: ToolContext):
|
|
action = str(payload.get("action", "add")).strip()
|
|
|
|
bus.log(
|
|
"MEMORY",
|
|
"memory_execute",
|
|
"INFO",
|
|
{"action": action}
|
|
)
|
|
|
|
match action:
|
|
case "add":
|
|
return self.add(payload)
|
|
|
|
case "search":
|
|
return self.search(payload)
|
|
|
|
case "list":
|
|
return self.list_all()
|
|
|
|
case "clear":
|
|
return self.clear()
|
|
|
|
case _:
|
|
raise ValueError(f"Unknown memory action: {action}")
|
|
|
|
# =========================
|
|
# ADD MEMORY
|
|
# =========================
|
|
|
|
def add(self, payload: dict[str, Any]):
|
|
entry = payload.get("entry")
|
|
|
|
if not isinstance(entry, dict):
|
|
raise ValueError("entry must be dict")
|
|
|
|
memory = self._load()
|
|
|
|
record = {
|
|
"id": len(memory) + 1,
|
|
"timestamp": time.time(),
|
|
"entry": entry
|
|
}
|
|
|
|
memory.append(record)
|
|
self._save(memory)
|
|
|
|
return {
|
|
"status": "ok",
|
|
"stored": record
|
|
}
|
|
|
|
# =========================
|
|
# SEARCH MEMORY
|
|
# =========================
|
|
|
|
def search(self, payload: dict[str, Any]):
|
|
query = payload.get("query", "")
|
|
|
|
if not isinstance(query, str):
|
|
raise ValueError("query must be string")
|
|
|
|
memory = self._load()
|
|
|
|
results = []
|
|
|
|
for item in memory:
|
|
entry = item.get("entry", {})
|
|
text_blob = json.dumps(entry).lower()
|
|
|
|
if query.lower() in text_blob:
|
|
results.append(item)
|
|
|
|
return {
|
|
"query": query,
|
|
"results": results,
|
|
"count": len(results)
|
|
}
|
|
|
|
# =========================
|
|
# LIST ALL
|
|
# =========================
|
|
|
|
def list_all(self):
|
|
return {
|
|
"memory": self._load()
|
|
}
|
|
|
|
# =========================
|
|
# CLEAR MEMORY
|
|
# =========================
|
|
|
|
def clear(self):
|
|
self._save([])
|
|
return {"status": "cleared"}
|
|
|
|
# =========================
|
|
# STORAGE LAYER
|
|
# =========================
|
|
|
|
def _ensure_file(self):
|
|
self.memory_file.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
if not self.memory_file.exists():
|
|
self.memory_file.write_text("[]", encoding="utf-8")
|
|
|
|
def _load(self) -> list[dict[str, Any]]:
|
|
try:
|
|
return json.loads(self.memory_file.read_text(encoding="utf-8"))
|
|
except Exception:
|
|
return []
|
|
|
|
def _save(self, data: list[dict[str, Any]]):
|
|
self.memory_file.write_text(
|
|
json.dumps(data, indent=2),
|
|
encoding="utf-8"
|
|
)
|
|
|
|
|
|
# =========================
# REGISTER
# =========================

# Instantiate the tool at import time and add it to the shared tool registry.
# Constructing MemoryTool creates the store file if it does not exist yet.
registry.register(MemoryTool())
