Skip to main content

Stateful Agents

Learn how to add memory and persistent state to your AI agents, enabling them to learn from past interactions and make smarter decisions over time.

Stateless vs Stateful

Stateless Agent

Each interaction is independent. No memory of previous conversations or decisions.

# Every call is fresh - no context
# NOTE: top-level `await` assumes an async context (e.g. an asyncio REPL or
# an enclosing `async def`).
result1 = await agent.ainvoke({"input": "Buy ETH"})
result2 = await agent.ainvoke({"input": "What did I just buy?"})
# Agent doesn't know about the previous buy

Stateful Agent

Maintains context across interactions. Remembers previous decisions and their outcomes.

# Agent remembers context
# Same calls as the stateless example — the difference is that this agent
# is configured with memory (see the sections below).
result1 = await agent.ainvoke({"input": "Buy ETH"})
result2 = await agent.ainvoke({"input": "What did I just buy?"})
# Agent knows: "You bought ETH in the previous transaction"

Memory Types

1. Conversation Memory

Remember recent messages in the conversation:

from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationChain
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

llm = ChatOpenAI(model="gpt-4", temperature=0)

memory = ConversationBufferMemory(
return_messages=True,
memory_key="history",
)

prompt = ChatPromptTemplate.from_messages([
("system", "You are a DeFi assistant."),
MessagesPlaceholder(variable_name="history"),
("human", "{input}"),
])

chain = ConversationChain(
llm=llm,
memory=memory,
prompt=prompt,
)

# Conversations now have context
await chain.acall({"input": "My vault address is 0x123..."})
await chain.acall({"input": "What is my vault address?"})
# Agent remembers: 0x123...

2. Summary Memory

Compress long conversations into summaries:

from langchain.memory import ConversationSummaryMemory

# Summary memory compresses older turns into a running LLM-written summary
# instead of replaying the full transcript.
memory = ConversationSummaryMemory(
    llm=llm,  # summarization model — reuses the chat model defined above
    memory_key="summary",
)

# Long conversations get summarized, e.g.:
# "User created vault, executed 3 swaps, current balance is 5 ETH"

3. Vault Memory (ZeroQuant)

Track vault-specific state:

from zeroquant.langchain.memory import VaultMemory
from datetime import datetime

memory = VaultMemory()

# Register a vault and its static metadata
memory.add_vault("0x123...", {
    "created_at": datetime.now().timestamp(),
    "owner": "0xabc...",
})

# Record an executed operation against that vault
memory.add_operation("0x123...", {
    "type": "SWAP",
    "token_in": "WETH",
    "token_out": "USDC",
    "amount_in": "1.0",
    "timestamp": datetime.now().timestamp(),
    "tx_hash": "0x...",
})

# Query operation history (ordering is presumably newest-first — TODO confirm
# against the VaultMemory implementation)
history = memory.get_operations("0x123...", limit=10)

Building a Stateful Trading Agent

Step 1: Define State Schema

from dataclasses import dataclass, field
from typing import Optional
from enum import Enum


class TradeAction(Enum):
    """Direction of an executed trade."""

    BUY = "BUY"
    SELL = "SELL"


@dataclass
class Trade:
    """A single executed trade; `pnl` is filled in once the outcome is known."""

    timestamp: float              # Unix timestamp (seconds)
    action: TradeAction
    token: str                    # token symbol, e.g. "ETH"
    amount: str                   # kept as a string — avoids float precision loss
    price: float                  # entry price at execution time
    pnl: Optional[float] = None   # realized PnL; None until the position closes


@dataclass
class Position:
    """An open position for a single token."""

    amount: int                   # token amount — presumably base units; TODO confirm
    entry_price: float
    current_price: float
    unrealized_pnl: float


@dataclass
class AgentState:
    """Full persistent state for one vault-bound trading agent."""

    # Vault info
    vault_address: str
    vault_balance: int = 0

    # Trading history (default_factory gives each instance its own list)
    trades: list[Trade] = field(default_factory=list)

    # Performance metrics
    total_trades: int = 0
    win_rate: float = 0.0         # percentage, 0-100
    total_pnl: float = 0.0

    # Current positions — keys look like token identifiers; TODO confirm convention
    positions: dict[str, Position] = field(default_factory=dict)

    # Learned pattern descriptions used to bias future decisions
    successful_patterns: list[str] = field(default_factory=list)
    failed_patterns: list[str] = field(default_factory=list)

Step 2: Implement State Manager

import os
import json
import redis.asyncio as redis


class AgentStateManager:
    """Owns one vault's AgentState: Redis persistence plus trade bookkeeping.

    Annotations referencing the state schema are strings (forward references)
    so this example stays importable on its own.
    """

    def __init__(self, vault_address: str):
        # REDIS_URL comes from the environment; state starts empty until
        # `load()` is awaited.
        self.redis = redis.from_url(os.getenv("REDIS_URL"))
        self.state = self._initialize_state(vault_address)

    def _initialize_state(self, vault_address: str) -> "AgentState":
        """Return a fresh, empty state bound to `vault_address`."""
        return AgentState(vault_address=vault_address)

    async def load(self) -> None:
        """Replace in-memory state with the persisted copy, if one exists."""
        data = await self.redis.get(f"agent:{self.state.vault_address}")
        if data:
            parsed = json.loads(data)
            self.state = AgentState(
                vault_address=parsed["vault_address"],
                vault_balance=int(parsed["vault_balance"]),
                # save() stores the action enum as its string value, so the
                # enum must be rebuilt here; a plain Trade(**t) would leave
                # `action` as a str and break later `t.action.value` access.
                trades=[
                    Trade(**{**t, "action": TradeAction(t["action"])})
                    for t in parsed["trades"]
                ],
                total_trades=parsed["total_trades"],
                win_rate=parsed["win_rate"],
                total_pnl=parsed["total_pnl"],
                positions={k: Position(**v) for k, v in parsed["positions"].items()},
                successful_patterns=parsed["successful_patterns"],
                failed_patterns=parsed["failed_patterns"],
            )

    async def save(self) -> None:
        """Serialize the current state to JSON and write it to Redis."""
        serialized = {
            "vault_address": self.state.vault_address,
            # Stored as a string: wei-scale balances can exceed the integer
            # range that other JSON consumers handle safely.
            "vault_balance": str(self.state.vault_balance),
            "trades": [
                {
                    "timestamp": t.timestamp,
                    "action": t.action.value,
                    "token": t.token,
                    "amount": t.amount,
                    "price": t.price,
                    "pnl": t.pnl,
                }
                for t in self.state.trades
            ],
            "total_trades": self.state.total_trades,
            "win_rate": self.state.win_rate,
            "total_pnl": self.state.total_pnl,
            "positions": {
                k: {
                    "amount": v.amount,
                    "entry_price": v.entry_price,
                    "current_price": v.current_price,
                    "unrealized_pnl": v.unrealized_pnl,
                }
                for k, v in self.state.positions.items()
            },
            "successful_patterns": self.state.successful_patterns,
            "failed_patterns": self.state.failed_patterns,
        }
        await self.redis.set(
            f"agent:{self.state.vault_address}",
            json.dumps(serialized),
        )

    def record_trade(self, trade: "Trade") -> None:
        """Append `trade` and refresh the aggregate performance metrics."""
        self.state.trades.append(trade)
        self.state.total_trades += 1

        # Win rate treats trades with unknown (None) pnl as losses.
        wins = sum(1 for t in self.state.trades if (t.pnl or 0) > 0)
        self.state.win_rate = (wins / self.state.total_trades) * 100

        # Explicit None check — a pnl of exactly 0.0 is still a valid value.
        if trade.pnl is not None:
            self.state.total_pnl += trade.pnl

    def get_performance_summary(self) -> str:
        """Return a human-readable summary used in LLM prompts and logs."""
        recent_trades = ", ".join(
            f"{t.action.value} {t.token} @ ${t.price}"
            for t in self.state.trades[-5:]
        )
        return f"""Performance Summary:
- Total Trades: {self.state.total_trades}
- Win Rate: {self.state.win_rate:.1f}%
- Total PnL: ${self.state.total_pnl:.2f}
- Recent Trades: {recent_trades}"""

    def should_avoid_pattern(self, pattern: str) -> bool:
        """True if `pattern` previously led to a failed trade."""
        return pattern in self.state.failed_patterns

    def learn_from_trade(self, pattern: str, success: bool) -> None:
        """Record the outcome for `pattern`; a success clears prior failures."""
        if success:
            if pattern not in self.state.successful_patterns:
                self.state.successful_patterns.append(pattern)
            # A success overrides earlier failures of the same pattern.
            self.state.failed_patterns = [
                p for p in self.state.failed_patterns if p != pattern
            ]
        else:
            if pattern not in self.state.failed_patterns:
                self.state.failed_patterns.append(pattern)

Step 3: Stateful Decision Engine

import json
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate


class StatefulDecisionEngine:
    """LLM decision engine that conditions every call on past performance."""

    def __init__(self, state_manager: "AgentStateManager"):
        self.state_manager = state_manager
        # temperature=0 keeps decisions deterministic for a given prompt.
        self.llm = ChatOpenAI(model="gpt-4", temperature=0)

    async def analyze(self, market_data: dict) -> dict:
        """Ask the LLM for a trade signal, vetoing previously failed patterns.

        Returns the parsed signal dict with keys "action", "confidence",
        "reasoning", and "pattern".
        """
        # Include historical context (performance + learned patterns) in the prompt.
        prompt = ChatPromptTemplate.from_template("""
You are a trading agent with the following history:

{performance_summary}

Patterns to AVOID (previously failed):
{failed_patterns}

Patterns that worked well:
{successful_patterns}

Current Market Data:
{market_data}

Based on your past performance and current market conditions, provide a recommendation.
Consider:
1. Your historical win rate on similar setups
2. Patterns that have failed before
3. Current market conditions

Respond in JSON format:
{{
"action": "BUY" | "SELL" | "HOLD",
"confidence": 0-100,
"reasoning": "explanation including historical context",
"pattern": "description of this setup for learning"
}}
""")

        response = await self.llm.ainvoke(
            prompt.format(
                performance_summary=self.state_manager.get_performance_summary(),
                failed_patterns=", ".join(self.state_manager.state.failed_patterns) or "None",
                successful_patterns=", ".join(self.state_manager.state.successful_patterns) or "None",
                market_data=json.dumps(market_data),
            )
        )

        # NOTE(review): raises json.JSONDecodeError if the model wraps its
        # answer in prose or markdown fences — consider a JSON output parser.
        signal = json.loads(response.content)

        # Hard veto: never re-enter a setup that has already failed.
        if self.state_manager.should_avoid_pattern(signal.get("pattern", "")):
            signal["action"] = "HOLD"
            signal["reasoning"] = f'Avoiding pattern "{signal["pattern"]}" due to previous failures'
            signal["confidence"] = 0

        return signal

    async def record_outcome(self, signal: dict, success: bool) -> None:
        """Feed the realized outcome back into state and persist it."""
        self.state_manager.learn_from_trade(signal.get("pattern", "unknown"), success)
        await self.state_manager.save()

Step 4: Complete Stateful Agent

import os
import asyncio
from web3 import Web3
from zeroquant import ZeroQuantClient

WETH_ADDRESS = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"


class StatefulTradingAgent:
    """Trading agent that loads persistent state, trades, and learns."""

    # String annotation (forward reference) keeps this example importable
    # without the zeroquant package installed.
    def __init__(self, client: "ZeroQuantClient", vault_address: str):
        self.client = client
        self.state_manager = AgentStateManager(vault_address)
        self.decision_engine = StatefulDecisionEngine(self.state_manager)

    async def initialize(self) -> None:
        """Load persisted state so past performance informs new decisions."""
        await self.state_manager.load()
        print("Agent state loaded")
        print(self.state_manager.get_performance_summary())

    async def execute_trading_cycle(self, market_data: dict) -> None:
        """Run one decide → trade → evaluate → learn cycle."""
        # Local import: `datetime` was used but never imported in this example.
        from datetime import datetime

        # Get a decision that includes historical context.
        signal = await self.decision_engine.analyze(market_data)

        print(f"Decision: {signal['action']} ({signal['confidence']}% confidence)")
        print(f"Reasoning: {signal['reasoning']}")

        # Skip low-conviction signals entirely.
        if signal["action"] == "HOLD" or signal["confidence"] < 70:
            return

        entry_price = market_data["price"]
        success = False

        try:
            tx = await self._execute_trade(signal, market_data)
            await tx.wait()

            # Wait for outcome (simplified - in reality you'd monitor the position)
            exit_price = await self._wait_for_exit(signal, entry_price)
            pnl = self._calculate_pnl(signal["action"], entry_price, exit_price)

            success = pnl > 0

            # Record the trade
            self.state_manager.record_trade(Trade(
                timestamp=datetime.now().timestamp(),
                action=TradeAction(signal["action"]),
                token=market_data["symbol"],
                amount="1.0",  # Simplified
                price=entry_price,
                pnl=pnl,
            ))

        except Exception as e:
            # Best-effort cycle: a failed trade is recorded as a failed
            # pattern below rather than crashing the loop.
            print(f"Trade failed: {e}")
            success = False

        # Learn from the outcome either way.
        await self.decision_engine.record_outcome(signal, success)

        print(f"Trade {'succeeded' if success else 'failed'}")
        print(self.state_manager.get_performance_summary())

    async def _wait_for_exit(self, signal: dict, entry_price: float) -> float:
        """Placeholder exit monitor (was called but missing in the original).

        Replace with real position monitoring. Returning the entry price
        yields 0 PnL, so the cycle records the trade as unsuccessful.
        """
        return entry_price

    async def _execute_trade(self, signal: dict, market: dict):
        """Submit the swap for `signal` through the vault client."""
        # BUY: spend WETH for the token; SELL: the reverse path.
        path = (
            [WETH_ADDRESS, market["token_address"]]
            if signal["action"] == "BUY"
            else [market["token_address"], WETH_ADDRESS]
        )
        return await self.client.execute_swap(
            amount_in=Web3.to_wei(0.1, "ether"),
            path=path,
            slippage_bps=100,  # 1% slippage tolerance
        )

    def _calculate_pnl(self, action: str, entry: float, exit_price: float) -> float:
        """Return round-trip PnL as a percentage of the entry price."""
        # Local `exit` renamed to `exit_price` to avoid shadowing the builtin.
        if action == "BUY":
            return ((exit_price - entry) / entry) * 100
        return ((entry - exit_price) / entry) * 100


# Usage
async def main():
w3 = Web3(Web3.HTTPProvider(os.getenv("RPC_URL")))
client = ZeroQuantClient(
web3=w3,
private_key=os.getenv("PRIVATE_KEY"),
factory_address=os.getenv("FACTORY_ADDRESS"),
permission_manager_address=os.getenv("PERMISSION_MANAGER_ADDRESS"),
)

agent = StatefulTradingAgent(client, vault_address="0x123...")
await agent.initialize()

# Run trading cycles
while True:
market_data = await fetch_market_data() # Your market data source
await agent.execute_trading_cycle(market_data)
await asyncio.sleep(60) # Check every minute


if __name__ == "__main__":
asyncio.run(main())

Memory Patterns Comparison

| Pattern | Use Case | Pros | Cons |
| --- | --- | --- | --- |
| Buffer Memory | Short conversations | Simple, fast | Limited history |
| Summary Memory | Long sessions | Compressed context | Loses details |
| Vector Store | Knowledge retrieval | Semantic search | Complexity |
| Redis/DB | Persistent state | Survives restarts | External dependency |

Best Practices

1. Memory Window Management

Don't let memory grow unbounded:

class BoundedMemory:
    """Fixed-capacity memory that summarizes, then drops, the oldest items."""

    def __init__(self, max_items: int = 100):
        self.max_items = max_items
        self.items: list = []

    def add(self, item: object) -> None:
        """Append `item`; when capacity is exceeded, compact the 10 oldest.

        NOTE(review): the trim size (10) is fixed regardless of `max_items`;
        with max_items < 10 a single trim can empty the buffer — confirm intent.
        """
        self.items.append(item)
        if len(self.items) > self.max_items:
            # Summarize old items before removing them
            self._summarize_oldest(10)
            self.items = self.items[10:]

    def _summarize_oldest(self, count: int) -> None:
        # Hook for summarizing the oldest `count` items before they are dropped.
        pass

2. Periodic State Snapshots

Save state regularly:

import asyncio


async def periodic_save(state_manager: "AgentStateManager") -> None:
    """Persist agent state every 60 seconds, forever.

    The annotation is a string (forward reference) so this snippet stays
    importable without the state-manager module on the path.
    """
    while True:
        await state_manager.save()
        print("State checkpoint saved")
        await asyncio.sleep(60)  # checkpoint interval: every minute


# Run alongside your main loop
# NOTE: asyncio.create_task() must be called from inside a running event loop.
asyncio.create_task(periodic_save(state_manager))

3. Memory Relevance Scoring

Prioritize relevant memories:

def get_relevant_memories(query: str, memories: list[dict]) -> list[dict]:
    """Return the five memories most relevant to `query`, highest score first.

    Relies on `calculate_similarity(query, text)` (defined elsewhere) to score
    each memory's "content" field; the score is attached as "relevance".
    """
    scored = [
        {**m, "relevance": calculate_similarity(query, m["content"])}
        for m in memories
    ]
    scored.sort(key=lambda entry: entry["relevance"], reverse=True)
    return scored[:5]

What's Next?