158 lines
4.2 KiB
Python
158 lines
4.2 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
import requests
|
|
|
|
from core.tools.base import BaseTool, ToolContext
|
|
from core.tools.registry import registry
|
|
from core.events import bus
|
|
from core.config import OLLAMA_URL
|
|
|
|
|
|
class OllamaTool(BaseTool):
|
|
"""
|
|
Local LLM interface via Ollama.
|
|
|
|
Enables the agent to call local models for reasoning,
|
|
summarization, and transformation tasks.
|
|
"""
|
|
|
|
name = "ollama"
|
|
description = "Local LLM inference via Ollama"
|
|
|
|
# =========================================================
|
|
# EXECUTE
|
|
# =========================================================
|
|
|
|
def execute(self, payload: dict[str, Any], ctx: ToolContext):
|
|
action = str(payload.get("action", "generate")).strip()
|
|
|
|
bus.log(
|
|
"OLLAMA",
|
|
"ollama_execute",
|
|
"INFO",
|
|
{"action": action}
|
|
)
|
|
|
|
match action:
|
|
case "generate":
|
|
return self.generate(payload, ctx)
|
|
|
|
case "chat":
|
|
return self.chat(payload, ctx)
|
|
|
|
case "models":
|
|
return self.list_models()
|
|
|
|
case _:
|
|
raise ValueError(f"Unknown ollama action: {action}")
|
|
|
|
# =========================================================
|
|
# GENERATE (single prompt)
|
|
# =========================================================
|
|
|
|
def generate(self, payload: dict[str, Any], ctx: ToolContext):
|
|
model = payload.get("model", "llama3")
|
|
prompt = payload.get("prompt")
|
|
|
|
if not isinstance(prompt, str):
|
|
raise ValueError("prompt must be string")
|
|
|
|
url = f"{OLLAMA_URL}/api/generate"
|
|
|
|
data = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"stream": False
|
|
}
|
|
|
|
if ctx.dry_run:
|
|
return {
|
|
"dry_run": True,
|
|
"model": model,
|
|
"prompt_preview": prompt[:200]
|
|
}
|
|
|
|
try:
|
|
response = requests.post(url, json=data, timeout=(5, 120))
|
|
response.raise_for_status()
|
|
|
|
return {
|
|
"model": model,
|
|
"response": response.json().get("response", ""),
|
|
}
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
return {
|
|
"status": "error",
|
|
"error": str(e)
|
|
}
|
|
|
|
# =========================================================
|
|
# CHAT (multi-message style)
|
|
# =========================================================
|
|
|
|
def chat(self, payload: dict[str, Any], ctx: ToolContext):
|
|
model = payload.get("model", "llama3")
|
|
messages = payload.get("messages")
|
|
|
|
if not isinstance(messages, list):
|
|
raise ValueError("messages must be list[dict]")
|
|
|
|
url = f"{OLLAMA_URL}/api/chat"
|
|
|
|
data = {
|
|
"model": model,
|
|
"messages": messages,
|
|
"stream": False
|
|
}
|
|
|
|
if ctx.dry_run:
|
|
return {
|
|
"dry_run": True,
|
|
"model": model,
|
|
"message_count": len(messages)
|
|
}
|
|
|
|
try:
|
|
response = requests.post(url, json=data, timeout=(5, 180))
|
|
response.raise_for_status()
|
|
|
|
return {
|
|
"model": model,
|
|
"response": response.json().get("message", {}).get("content", "")
|
|
}
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
return {
|
|
"status": "error",
|
|
"error": str(e)
|
|
}
|
|
|
|
# =========================================================
|
|
# LIST MODELS
|
|
# =========================================================
|
|
|
|
def list_models(self):
|
|
url = f"{OLLAMA_URL}/api/tags"
|
|
|
|
try:
|
|
response = requests.get(url, timeout=(3, 10))
|
|
response.raise_for_status()
|
|
|
|
return {
|
|
"models": response.json().get("models", [])
|
|
}
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
return {
|
|
"status": "error",
|
|
"error": str(e)
|
|
}
|
|
|
|
|
|
# =========================================================
|
|
# REGISTER TOOL
|
|
# =========================================================
|
|
|
|
# Register an OllamaTool instance with the tool registry when this module is imported.
registry.register(OllamaTool())
