Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions __tests__/disconnects.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,170 @@ describe.each(testMatrix())(
});
},
);

describe.each(testMatrix())(
'procedures should handle calling a procedure on a closed transport ($transport.name transport, $codec.name codec)',
async ({ transport, codec }) => {
const opts = { codec: codec.codec };

const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups();
let getClientTransport: TestSetupHelpers['getClientTransport'];
let getServerTransport: TestSetupHelpers['getServerTransport'];
beforeEach(async () => {
const setup = await transport.setup({ client: opts, server: opts });
getClientTransport = setup.getClientTransport;
getServerTransport = setup.getServerTransport;

return async () => {
await postTestCleanup();
await setup.cleanup();
};
});

test('rpc', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { test: TestServiceSchema };
const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

clientTransport.close();

const result = await client.test.add.rpc({ n: 3 });
expect(result).toStrictEqual({
ok: false,
payload: {
code: UNEXPECTED_DISCONNECT_CODE,
message: 'transport is closed',
},
});

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('stream', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { test: TestServiceSchema };
const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

clientTransport.close();

const { reqWritable, resReadable } = client.test.echo.stream({});

const result = await readNextResult(resReadable);

expect(result).toStrictEqual({
ok: false,
payload: {
code: UNEXPECTED_DISCONNECT_CODE,
message: 'transport is closed',
},
});

expect(await isReadableDone(resReadable)).toEqual(true);

expect(reqWritable.isWritable()).toEqual(false);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('subscription', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { test: TestServiceSchema };
const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

clientTransport.close();

const { resReadable } = client.test.unimplementedSubscription.subscribe(
{},
);

const result = await readNextResult(resReadable);
expect(result).toStrictEqual({
ok: false,
payload: {
code: UNEXPECTED_DISCONNECT_CODE,
message: 'transport is closed',
},
});

expect(await isReadableDone(resReadable)).toEqual(true);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('upload', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
const services = { test: TestServiceSchema };
const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

clientTransport.close();

const { reqWritable, finalize } = client.test.unimplementedUpload.upload(
{},
);

expect(reqWritable.isWritable()).toEqual(false);

await expect(finalize()).resolves.toMatchObject({
ok: false,
payload: {
code: UNEXPECTED_DISCONNECT_CODE,
message: 'transport is closed',
},
});

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});
},
);
48 changes: 46 additions & 2 deletions router/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ function handleProc(
procedureName: string,
abortSignal?: AbortSignal,
): AnyProcReturn {
if (transport.getStatus() === 'closed') {
return getPreClosedReturnForProc(procType);
}

const session =
transport.sessions.get(serverId) ??
transport.createUnconnectedSession(serverId);
Expand Down Expand Up @@ -493,14 +497,54 @@ function handleProc(
reqWritable.close();
}

return getReturnForProc(procType, resReadable, reqWritable, transport.log);
}

/**
* We want to make sure all the return types are valid even if the transport is closed.
* So we return a result that is already closed with `UNEXPECTED_DISCONNECT_CODE`.
*/
function getPreClosedReturnForProc(procType: ValidProcType): AnyProcReturn {
const readable = new ReadableImpl<unknown, Static<BaseErrorSchemaType>>();
const err = Err({
code: UNEXPECTED_DISCONNECT_CODE,
message: `transport is closed`,
});

readable._pushValue(err);
readable._triggerClose();

const writable = new WritableImpl<Static<PayloadType>>({
writeCb: () => {
// noop
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should probably error as the writable we close on UNEXPECTED_DISCONNECT in the normal case iirc so this should match?

},
closeCb: () => {
// noop
},
});

writable.close();

return getReturnForProc(procType, readable, writable);
}

/**
* Given a proc type, returns the appropriate return type for the proc.
*/
function getReturnForProc(
procType: ValidProcType,
resReadable: ReadableImpl<unknown, Static<BaseErrorSchemaType>>,
reqWritable: WritableImpl<Static<PayloadType>>,
log?: Logger,
): AnyProcReturn {
if (procType === 'subscription') {
return {
resReadable: resReadable,
};
}

if (procType === 'rpc') {
return getSingleMessage(resReadable, transport.log);
return getSingleMessage(resReadable, log);
}

if (procType === 'upload') {
Expand All @@ -519,7 +563,7 @@ function handleProc(
reqWritable.close();
}

return getSingleMessage(resReadable, transport.log);
return getSingleMessage(resReadable, log);
},
};
}
Expand Down
Loading