Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions langgraph-checkpoint-oceanbase-plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ In LangGraph's definition, Checkpointer refers to short-term memory (i.e. memory
```python
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver
from langgraph.checkpoint.oceanbase.pyoceanbase import PyOceanBaseSaver
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import HumanMessage
model = init_chat_model(model="qwen-max-latest", api_key="xxx",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", model_provider="openai",temperature=0)
DB_URI = "mysql://username:password@ip:port/database"
with PyMySQLSaver.from_conn_string(DB_URI) as checkpointer:
with PyOceanBaseSaver.from_conn_string(DB_URI) as checkpointer:
checkpointer.setup()

def call_model(state: MessagesState):
Expand Down Expand Up @@ -63,7 +63,7 @@ with PyMySQLSaver.from_conn_string(DB_URI) as checkpointer:
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_store
from langgraph.prebuilt import create_react_agent
from langgraph.store.mysql import PyMySQLStore
from langgraph.store.oceanbase import PyOceanBaseStore
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from typing_extensions import TypedDict
Expand All @@ -86,7 +86,7 @@ def save_user_info(user_info: UserInfo, config: RunnableConfig) -> str:
return "Successfully saved user info."


with PyMySQLStore.from_conn_string(DB_URI) as store:
with PyOceanBaseStore.from_conn_string(DB_URI) as store:
store.setup()
agent=create_react_agent(
model=model,
Expand Down
8 changes: 4 additions & 4 deletions langgraph-checkpoint-oceanbase-plugin/README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ langgraph-checkpoint-oceanbase 已经上传到 PyPI。可以使用下面的命
```python
from langchain.chat_models import init_chat_model
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver
from langgraph.checkpoint.oceanbase.pyoceanbase import PyOceanBaseSaver
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import HumanMessage
model = init_chat_model(model="qwen-max-latest", api_key="xxx",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", model_provider="openai",temperature=0)
DB_URI = "mysql://username:password@ip:port/database"
with PyMySQLSaver.from_conn_string(DB_URI) as checkpointer:
with PyOceanBaseSaver.from_conn_string(DB_URI) as checkpointer:
checkpointer.setup()

def call_model(state: MessagesState):
Expand Down Expand Up @@ -63,7 +63,7 @@ with PyMySQLSaver.from_conn_string(DB_URI) as checkpointer:
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_store
from langgraph.prebuilt import create_react_agent
from langgraph.store.mysql import PyMySQLStore
from langgraph.store.oceanbase import PyOceanBaseStore
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from typing_extensions import TypedDict
Expand All @@ -86,7 +86,7 @@ def save_user_info(user_info: UserInfo, config: RunnableConfig) -> str:
return "Successfully saved user info."


with PyMySQLStore.from_conn_string(DB_URI) as store:
with PyOceanBaseStore.from_conn_string(DB_URI) as store:
store.setup()
agent=create_react_agent(
model=model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import pymysql
from sqlalchemy import Engine, create_engine

from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
from langgraph.checkpoint.mysql.asyncmy import AsyncMySaver
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver
from langgraph.checkpoint.oceanbase.aio import AIOMySQLSaver
from langgraph.checkpoint.oceanbase.asyncmy import AsyncMySaver
from langgraph.checkpoint.oceanbase.pyoceanbase import PyOceanBaseSaver

DEFAULT_MYSQL_URI = "mysql://mysql:mysql@localhost:5441/"

Expand All @@ -23,19 +23,19 @@ def _checkpointer_pymysql():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with pymysql.connect(**PyOceanBaseSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
with PyMySQLSaver.from_conn_string(
with PyOceanBaseSaver.from_conn_string(
DEFAULT_MYSQL_URI + database
) as checkpointer:
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with pymysql.connect(**PyOceanBaseSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE {database}")

Expand All @@ -45,17 +45,17 @@ def _checkpointer_pymysql_pool():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with pymysql.connect(**PyOceanBaseSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
pool = get_pymysql_sqlalchemy_engine(DEFAULT_MYSQL_URI + database)
checkpointer = PyMySQLSaver(pool.raw_connection)
checkpointer = PyOceanBaseSaver(pool.raw_connection)
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with pymysql.connect(**PyOceanBaseSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE {database}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import pymysql
from sqlalchemy import Engine, create_engine

from langgraph.store.mysql.aio import AIOMySQLStore
from langgraph.store.mysql.asyncmy import AsyncMyStore
from langgraph.store.mysql.pymysql import PyMySQLStore
from langgraph.store.oceanbase.aio import AIOMySQLStore
from langgraph.store.oceanbase.asyncmy import AsyncMyStore
from langgraph.store.oceanbase.pyoceanbase import PyOceanBaseStore

DEFAULT_MYSQL_URI = "mysql://mysql:mysql@localhost:5441/"

Expand All @@ -23,17 +23,17 @@ def _store_pymysql():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with pymysql.connect(**PyOceanBaseStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
# yield store
with PyMySQLStore.from_conn_string(DEFAULT_MYSQL_URI + database) as store:
with PyOceanBaseStore.from_conn_string(DEFAULT_MYSQL_URI + database) as store:
store.setup()
yield store
finally:
# drop unique db
with pymysql.connect(**PyMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with pymysql.connect(**PyOceanBaseStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE {database}")

Expand All @@ -43,18 +43,18 @@ def _store_pymysql_pool():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with pymysql.connect(**PyOceanBaseStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
# yield store
engine = get_pymysql_sqlalchemy_engine(DEFAULT_MYSQL_URI + database)
store = PyMySQLStore(engine.raw_connection)
store = PyOceanBaseStore(engine.raw_connection)
store.setup()
yield store
finally:
# drop unique db
with pymysql.connect(**PyMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with pymysql.connect(**PyOceanBaseStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE {database}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
get_checkpoint_id,
get_checkpoint_metadata,
)
from langgraph.checkpoint.mysql import _internal
from langgraph.checkpoint.mysql.base import BaseMySQLSaver
from langgraph.checkpoint.mysql.utils import (
from langgraph.checkpoint.oceanbase import _internal
from langgraph.checkpoint.oceanbase.base import BaseMySQLSaver
from langgraph.checkpoint.oceanbase.utils import (
deserialize_channel_values,
deserialize_pending_sends,
deserialize_pending_writes,
Expand Down Expand Up @@ -115,7 +115,7 @@ def list(
Iterator[CheckpointTuple]: An iterator of checkpoint tuples.

Examples:
>>> from langgraph.checkpoint.mysql import PyMySQLSaver
>>> from langgraph.checkpoint.oceanbase import PyOceanBaseSaver
>>> DB_URI = "mysql://mysql:mysql@localhost:5432/mysql"
>>> with PyMySQLSaver.from_conn_string(DB_URI) as memory:
... # Run a graph, then list the checkpoints
Expand Down Expand Up @@ -284,9 +284,9 @@ def put(

Examples:

>>> from langgraph.checkpoint.mysql import PyMySQLSaver
>>> from langgraph.checkpoint.oceanbase import PyOceanBaseSaver
>>> DB_URI = "mysql://mysql:mysql@localhost:5432/mysql"
>>> with PyMySQLSaver.from_conn_string(DB_URI) as memory:
>>> with PyOceanBaseSaver.from_conn_string(DB_URI) as memory:
>>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
>>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}}
>>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import aiomysql # type: ignore
from typing_extensions import Self, override

from langgraph.checkpoint.mysql import _ainternal
from langgraph.checkpoint.mysql.aio_base import BaseAsyncMySQLSaver
from langgraph.checkpoint.mysql.shallow import BaseShallowAsyncMySQLSaver
from langgraph.checkpoint.oceanbase import _ainternal
from langgraph.checkpoint.oceanbase.aio_base import BaseAsyncMySQLSaver
from langgraph.checkpoint.oceanbase.shallow import BaseShallowAsyncMySQLSaver
from langgraph.checkpoint.serde.base import SerializerProtocol

Conn = _ainternal.Conn[aiomysql.Connection] # For backward compatibility
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
get_checkpoint_id,
get_checkpoint_metadata,
)
from langgraph.checkpoint.mysql import _ainternal
from langgraph.checkpoint.mysql.base import BaseMySQLSaver
from langgraph.checkpoint.mysql.utils import (
from langgraph.checkpoint.oceanbase import _ainternal
from langgraph.checkpoint.oceanbase.base import BaseMySQLSaver
from langgraph.checkpoint.oceanbase.utils import (
deserialize_channel_values,
deserialize_pending_sends,
deserialize_pending_writes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from asyncmy.cursors import DictCursor # type: ignore
from typing_extensions import Self, override

from langgraph.checkpoint.mysql.aio_base import BaseAsyncMySQLSaver
from langgraph.checkpoint.mysql.shallow import BaseShallowAsyncMySQLSaver
from langgraph.checkpoint.oceanbase.aio_base import BaseAsyncMySQLSaver
from langgraph.checkpoint.oceanbase.shallow import BaseShallowAsyncMySQLSaver
from langgraph.checkpoint.serde.base import SerializerProtocol


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
CheckpointMetadata,
get_checkpoint_id,
)
from langgraph.checkpoint.mysql.utils import mysql_mariadb_branch
from langgraph.checkpoint.oceanbase.utils import mysql_mariadb_branch
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import TASKS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
from pymysql.cursors import DictCursor
from typing_extensions import Self, override

from langgraph.checkpoint.mysql import BaseSyncMySQLSaver, _internal
from langgraph.checkpoint.mysql import Conn as BaseConn
from langgraph.checkpoint.mysql.shallow import BaseShallowSyncMySQLSaver
from langgraph.checkpoint.oceanbase import BaseSyncMySQLSaver, _internal
from langgraph.checkpoint.oceanbase import Conn as BaseConn
from langgraph.checkpoint.oceanbase.shallow import BaseShallowSyncMySQLSaver
from langgraph.checkpoint.serde.base import SerializerProtocol

Conn = BaseConn[pymysql.Connection] # type: ignore


class PyMySQLSaver(BaseSyncMySQLSaver[pymysql.Connection, DictCursor]):
"""Checkpointer that stores checkpoints in a MySQL database."""
class PyOceanBaseSaver(BaseSyncMySQLSaver[pymysql.Connection, DictCursor]):
"""Checkpointer that stores checkpoints in an OceanBase database."""

@staticmethod
def parse_conn_string(conn_string: str) -> dict[str, Any]:
Expand Down Expand Up @@ -100,7 +100,7 @@ def from_conn_string(
conn_string=mysql://user:password@localhost/db?unix_socket=/path/to/socket
"""
with pymysql.connect(
**PyMySQLSaver.parse_conn_string(conn_string),
**PyOceanBaseSaver.parse_conn_string(conn_string),
autocommit=True,
) as conn:
yield cls(conn)
Expand All @@ -111,4 +111,4 @@ def _get_cursor_from_connection(conn: pymysql.Connection) -> DictCursor:
return conn.cursor(DictCursor)


__all__ = ["PyMySQLSaver", "ShallowPyMySQLSaver", "Conn"]
__all__ = ["PyOceanBaseSaver", "ShallowPyMySQLSaver", "Conn"]
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
CheckpointTuple,
get_checkpoint_metadata,
)
from langgraph.checkpoint.mysql import _ainternal, _internal
from langgraph.checkpoint.mysql.base import BaseMySQLSaver
from langgraph.checkpoint.mysql.utils import (
from langgraph.checkpoint.oceanbase import _ainternal, _internal
from langgraph.checkpoint.oceanbase.base import BaseMySQLSaver
from langgraph.checkpoint.oceanbase.utils import (
deserialize_channel_values,
deserialize_pending_sends,
deserialize_pending_writes,
Expand Down Expand Up @@ -432,7 +432,7 @@ def put(

Examples:

>>> from langgraph.checkpoint.mysql import PyMySQLSaver
>>> from langgraph.checkpoint.oceanbase import PyOceanBaseSaver
>>> DB_URI = "mysql://mysql:mysql@localhost:5432/mysql"
>>> with ShallowPyMySQLSaver.from_conn_string(DB_URI) as memory:
>>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from langgraph.store.oceanbase.aio import AIOMySQLStore
from langgraph.store.oceanbase.asyncmy import AsyncMyStore
from langgraph.store.oceanbase.pyoceanbase import PyOceanBaseStore

__all__ = ["AIOMySQLStore", "AsyncMyStore", "PyOceanBaseStore"]
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import aiomysql # type: ignore
from typing_extensions import Self, override

from langgraph.store.mysql.aio_base import BaseAsyncMySQLStore
from langgraph.store.oceanbase.aio_base import BaseAsyncMySQLStore

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import orjson

from langgraph.checkpoint.mysql import _ainternal
from langgraph.checkpoint.oceanbase import _ainternal
from langgraph.store.base import (
GetOp,
ListNamespacesOp,
Expand All @@ -18,7 +18,7 @@
SearchOp,
)
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.mysql.base import (
from langgraph.store.oceanbase.base import (
BaseMySQLStore,
Row,
_decode_ns_bytes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from asyncmy.cursors import DictCursor # type: ignore
from typing_extensions import Self, override

from langgraph.store.mysql.aio_base import BaseAsyncMySQLStore
from langgraph.store.oceanbase.aio_base import BaseAsyncMySQLStore

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import orjson
from typing_extensions import TypedDict

from langgraph.checkpoint.mysql import _ainternal as _ainternal
from langgraph.checkpoint.mysql import _internal as _internal
from langgraph.checkpoint.mysql.utils import mysql_mariadb_branch
from langgraph.checkpoint.oceanbase import _ainternal as _ainternal
from langgraph.checkpoint.oceanbase import _internal as _internal
from langgraph.checkpoint.oceanbase.utils import mysql_mariadb_branch
from langgraph.store.base import (
BaseStore,
GetOp,
Expand Down
Loading
Loading