Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .devcontainer/.devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"dockerComposeFile": [
"docker-compose.yml"
],
"service": "devcontainer",
"workspaceFolder": "/workspace",
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {}
}
}
21 changes: 21 additions & 0 deletions .devcontainer/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
version: '3.8'
volumes:
aip_db:
aip_ts:
services:
devcontainer:
image: mcr.microsoft.com/devcontainers/python:3.13
volumes:
- ..:/workspace:cached
- /var/run/docker.sock:/var/run/docker.sock
command: sleep infinity
environment:
- TZ=America/New_York

tailscale:
image: tailscale/tailscale:latest
restart: unless-stopped
environment:
- TS_STATE_DIR=/var/run/tailscale
volumes:
- aip_ts:/var/run/tailscale
3 changes: 3 additions & 0 deletions packages/atproto_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from atproto_client.client.client import Client
from atproto_client.client.session import Session, SessionEvent

from atproto_client.client.cli import main

__all__ = [
'AsyncClient',
'Client',
'models',
'SessionEvent',
'Session',
'main',
]
2 changes: 1 addition & 1 deletion packages/atproto_client/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _invoke(self, invoke_type: InvokeType, **kwargs: t.Any) -> Response:
class AsyncClientBase(_ClientCommonMethodsMixin):
"""Low-level methods are here."""

def __init__(self, base_url: t.Optional[str] = None, request: t.Optional[AsyncRequest] = None) -> None:
def __init__(self, base_url: t.Optional[str] = None, request: t.Optional[AsyncRequest] = None, *args: t.Any, **kwargs: t.Any) -> None:
if request is None:
request = AsyncRequest()

Expand Down
88 changes: 88 additions & 0 deletions packages/atproto_client/client/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
import asyncio
import aiohttp
from typing import Callable, Any
from authlib.jose import JsonWebKey

from atproto_client import AsyncClient, Session, SessionEvent
from atproto_client.exceptions import UnauthorizedError


async def fetch_credentials(aip_jwk: str, aip_server: str = "https://grazeaip.tunn.dev") -> Session:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {aip_jwk}",
}
async with session.get(f"{aip_server}/internal/api/me", headers=headers) as response:
me_response = await response.json()

if "error" in me_response:
raise Exception(me_response["error"])

oauth_session_valid = me_response.get("oauth_session_valid", False)
app_password_session_valid = me_response.get(
"app_password_session_valid", False)

if not oauth_session_valid and not app_password_session_valid:
raise Exception("No valid session found")

async with session.get(f"{aip_server}/internal/api/credentials", headers=headers) as response:
session_response = await response.json()

if "error" in session_response:
raise Exception(session_response["error"])

if "type" in session_response and session_response["type"] == "dpop":
session = Session(
handle=me_response.get("handle", ""),
did=me_response.get("did", ""),
pds_endpoint=me_response.get("pds", None),
static_dpop_token=session_response.get("token", None),
static_dpop_issuer=session_response.get("issuer", None),
static_dpop_jwk=JsonWebKey.import_key(session_response.get("jwk", None)),
)
return session
raise Exception("oops")


async def retry_invoke(client: AsyncClient, session: Session, func: Callable, *args, **kwargs) -> Any:
for i in range(2):
try:
return await func(*args, **kwargs)
except UnauthorizedError as e:
if e.response is not None and e.response.status_code == 401:
if "www-authenticate" in e.response.headers and "use_dpop_nonce" in e.response.headers["www-authenticate"]:
session.static_dpop_nonce = e.response.headers["dpop-nonce"]
continue
raise e

async def realMain() -> None:
session = await fetch_credentials(os.getenv("AIP_JWK", ""), "https://auth.m.graze.social")

client = AsyncClient()
await client._set_session(SessionEvent.IMPORT, session)

create_record_args = {
"collection": "garden.lexicon.deeply-mouse.profile",
"repo": session.did,
"record": {
"$type": "garden.lexicon.deeply-mouse.profile",
"name": session.handle,
},
"validate": False,
}


created_record = await retry_invoke(client, session, client.com.atproto.repo.create_record, {**create_record_args})
print(created_record)

records = await retry_invoke(client, session, client.com.atproto.repo.list_records, {"collection": "garden.lexicon.deeply-mouse.profile", "repo": session.did})
print(records)


def main() -> None:
asyncio.run(realMain())


if __name__ == '__main__':
main()
73 changes: 61 additions & 12 deletions packages/atproto_client/client/methods_mixin/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import typing as t
from datetime import timedelta
from authlib.common.security import generate_token
from authlib.oauth2.rfc7636 import create_s256_code_challenge
from authlib.jose import jwt
import json

from atproto_client.client.methods_mixin.time import TimeMethodsMixin
from atproto_client.client.session import (
Expand Down Expand Up @@ -96,26 +100,38 @@ def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
self._session_dispatcher = SessionDispatcher()

def _register_auth_headers_source(self) -> None:
self.request.add_additional_headers_source(self._get_access_auth_headers)
self.request.add_additional_headers_source(
self._get_access_auth_headers)

def _should_refresh_session(self) -> bool:
if not self._session or not self._session.access_jwt_payload or not self._session.access_jwt_payload.exp:
if not self._session:
raise LoginRequiredError

if self._session.static_access_token or self._session.static_dpop_token:
return False

if self._session.access_jwt is None or self._session.access_jwt_payload is None or self._session.access_jwt_payload.exp is None:
raise LoginRequiredError

expired_at = self.get_time_from_timestamp(self._session.access_jwt_payload.exp)
expired_at = expired_at - timedelta(minutes=15) # let's update the token a bit earlier than required
expired_at = self.get_time_from_timestamp(
self._session.access_jwt_payload.exp)
# let's update the token a bit earlier than required
expired_at = expired_at - timedelta(minutes=15)

return self.get_current_time() > expired_at

def _set_or_update_session(self, session: SessionResponse, pds_endpoint: str) -> 'Session':
if not self._session:
self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)
if isinstance(session, Session):
self._session = session
else:
self._session = Session(
access_jwt=session.access_jwt,
refresh_jwt=session.refresh_jwt,
did=session.did,
handle=session.handle,
pds_endpoint=pds_endpoint,
)
self._session_dispatcher.set_session(self._session)
self._register_auth_headers_source()
else:
Expand All @@ -137,10 +153,43 @@ def _set_session_common(self, session: SessionResponse, current_pds: str) -> Ses
self._update_pds_endpoint(pds_endpoint)
return self._set_or_update_session(session, pds_endpoint)

def _get_access_auth_headers(self) -> t.Dict[str, str]:
def _get_access_auth_headers(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, str]:
if not self._session:
return {}

if self._session.static_access_token is not None:
return {'Authorization': f'Bearer {self._session.static_access_token}'}

if self._session.static_dpop_token is not None and self._session.static_dpop_jwk is not None:

htm = kwargs.get("method", "")
htu = kwargs.get("url", "")

dpop_pub_jwk = json.loads(
self._session.static_dpop_jwk.as_json(is_private=False))
now = self.get_current_time().timestamp()

body = {
"iss": self._session.static_dpop_issuer,
"iat": int(now),
"exp": int(now) + 10,
"jti": generate_token(),
"htm": htm,
"htu": htu,
"ath": create_s256_code_challenge(self._session.static_dpop_token),
}

if self._session.static_dpop_nonce is not None:
body["nonce"] = self._session.static_dpop_nonce

dpop_jwt_encoded = dpop_proof = jwt.encode(
{"typ": "dpop+jwt", "alg": "ES256", "jwk": dpop_pub_jwk}, body, self._session.static_dpop_jwk).decode("utf-8")

return {
"Authorization": f"DPoP {self._session.static_dpop_token}",
"DPoP": dpop_jwt_encoded,
}

return {'Authorization': f'Bearer {self._session.access_jwt}'}

def _get_refresh_auth_headers(self) -> t.Dict[str, str]:
Expand Down
48 changes: 34 additions & 14 deletions packages/atproto_client/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass
from enum import Enum

from authlib.jose.rfc7517 import Key

import typing_extensions as te
from atproto_core.did_doc import DidDocument, is_valid_did_doc
from atproto_server.auth.jwt import get_jwt_payload
Expand All @@ -30,12 +32,14 @@ class SessionEvent(Enum):
]

SessionChangeCallback = t.Callable[[SessionEvent, 'Session'], None]
AsyncSessionChangeCallback = t.Callable[[SessionEvent, 'Session'], t.Coroutine[t.Any, t.Any, None]]
AsyncSessionChangeCallback = t.Callable[[
SessionEvent, 'Session'], t.Coroutine[t.Any, t.Any, None]]


def _session_exists(session: t.Optional['Session']) -> te.TypeGuard['Session']:
if not session:
raise ValueError('Session does not exists. It is not possible to dispatch session change event.')
raise ValueError(
'Session does not exists. It is not possible to dispatch session change event.')

return isinstance(session, Session)

Expand All @@ -45,7 +49,8 @@ def __init__(self, session: t.Optional['Session'] = None) -> None:
self._session: t.Optional['Session'] = session

self._on_session_change_callbacks: t.List[SessionChangeCallback] = []
self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = []
self._on_session_change_async_callbacks: t.List[AsyncSessionChangeCallback] = [
]

def set_session(self, session: 'Session') -> None:
self._session = session
Expand All @@ -60,7 +65,8 @@ def dispatch_session_change(self, event: SessionEvent) -> None:
self._call_on_session_change_callbacks(event)

async def dispatch_session_change_async(self, event: SessionEvent) -> None:
self._call_on_session_change_callbacks(event) # Allow synchronous callbacks in the async client
# Allow synchronous callbacks in the async client
self._call_on_session_change_callbacks(event)
await self._call_on_session_change_callbacks_async(event)

def _call_on_session_change_callbacks(self, event: SessionEvent) -> None:
Expand All @@ -76,7 +82,8 @@ async def _call_on_session_change_callbacks_async(self, event: SessionEvent) ->

coroutines: t.List[t.Coroutine[t.Any, t.Any, None]] = []
for on_session_change_async_callback in self._on_session_change_async_callbacks:
coroutines.append(on_session_change_async_callback(event, session_copy))
coroutines.append(
on_session_change_async_callback(event, session_copy))

await asyncio.gather(*coroutines)

Expand All @@ -85,16 +92,28 @@ async def _call_on_session_change_callbacks_async(self, event: SessionEvent) ->
class Session:
handle: str
did: str
access_jwt: str
refresh_jwt: str
pds_endpoint: t.Optional[str] = 'https://bsky.social' # Backward compatibility for old sessions
access_jwt: t.Optional[str] = None
refresh_jwt: t.Optional[str] = None
# Backward compatibility for old sessions
pds_endpoint: t.Optional[str] = 'https://bsky.social'

static_access_token: t.Optional[str] = None

static_dpop_token: t.Optional[str] = None
static_dpop_jwk: t.Optional[Key] = None
static_dpop_issuer: t.Optional[str] = None
static_dpop_nonce: t.Optional[str] = None

@property
def access_jwt_payload(self) -> 'JwtPayload':
def access_jwt_payload(self) -> t.Union['JwtPayload', None]:
if self.access_jwt is None:
return None
return get_jwt_payload(self.access_jwt)

@property
def refresh_jwt_payload(self) -> 'JwtPayload':
def refresh_jwt_payload(self) -> t.Union['JwtPayload', None]:
if self.refresh_jwt is None:
return None
return get_jwt_payload(self.refresh_jwt)

def __repr__(self) -> str:
Expand All @@ -107,8 +126,8 @@ def encode(self) -> str:
payload = [
self.handle,
self.did,
self.access_jwt,
self.refresh_jwt,
self.access_jwt or "",
self.refresh_jwt or "",
self.pds_endpoint,
]
return _SESSION_STRING_SEPARATOR.join(payload)
Expand All @@ -122,11 +141,12 @@ def decode(cls, session_string: str) -> 'Session':
handle, did, access_jwt, refresh_jwt = fields
return cls(handle, did, access_jwt, refresh_jwt)

handle, did, access_jwt, refresh_jwt, pds_endpoint = session_string.split(_SESSION_STRING_SEPARATOR)
handle, did, access_jwt, refresh_jwt, pds_endpoint = session_string.split(
_SESSION_STRING_SEPARATOR)
return cls(handle, did, access_jwt, refresh_jwt, pds_endpoint)

def copy(self) -> 'Session':
return Session(self.handle, self.did, self.access_jwt, self.refresh_jwt, self.pds_endpoint)
return Session(self.handle, self.did, self.access_jwt, self.refresh_jwt, self.pds_endpoint, self.static_access_token, self.static_dpop_token, self.static_dpop_jwk, self.static_dpop_issuer, self.static_dpop_nonce)

#: Alias for :attr:`encode`
export = encode
Expand Down
Loading