diff --git a/package-lock.json b/package-lock.json index baaf96b82..c2ecce1fb 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,7 +9,6 @@ "version": "0.24.0", "license": "Apache-2.0", "dependencies": { - "@viamrobotics/rpc": "^0.2.6", "exponential-backoff": "^3.1.1" }, "devDependencies": { @@ -1755,26 +1754,6 @@ } } }, - "node_modules/@viamrobotics/rpc": { - "version": "0.2.6", - "resolved": "https://registry.npmjs.org/@viamrobotics/rpc/-/rpc-0.2.6.tgz", - "integrity": "sha512-0zEW6P+kxvYJdE/DsQSRjnOZZPM9HoKWvZBoGQUVxEX0fxo6j3kxoTn7HjOcGSp6yC3UQbVW+nlwzslI9VslPQ==", - "dependencies": { - "@improbable-eng/grpc-web": "^0.13.0", - "google-protobuf": "^3.14.0" - } - }, - "node_modules/@viamrobotics/rpc/node_modules/@improbable-eng/grpc-web": { - "version": "0.13.0", - "resolved": "https://registry.npmjs.org/@improbable-eng/grpc-web/-/grpc-web-0.13.0.tgz", - "integrity": "sha512-vaxxT+Qwb7GPqDQrBV4vAAfH0HywgOLw6xGIKXd9Q8hcV63CQhmS3p4+pZ9/wVvt4Ph3ZDK9fdC983b9aGMUFg==", - "dependencies": { - "browser-headers": "^0.4.0" - }, - "peerDependencies": { - "google-protobuf": "^3.2.0" - } - }, "node_modules/@viamrobotics/typescript-config": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/@viamrobotics/typescript-config/-/typescript-config-0.1.0.tgz", @@ -2248,7 +2227,8 @@ "node_modules/browser-headers": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/browser-headers/-/browser-headers-0.4.1.tgz", - "integrity": "sha512-CA9hsySZVo9371qEHjHZtYxV2cFtVj5Wj/ZHi8ooEsrtm4vOnl9Y9HmyYWk9q+05d7K3rdoAE0j3MVEFVvtQtg==" + "integrity": "sha512-CA9hsySZVo9371qEHjHZtYxV2cFtVj5Wj/ZHi8ooEsrtm4vOnl9Y9HmyYWk9q+05d7K3rdoAE0j3MVEFVvtQtg==", + "dev": true }, "node_modules/browserslist": { "version": "4.23.0", @@ -4827,7 +4807,8 @@ "node_modules/google-protobuf": { "version": "3.21.2", "resolved": "https://registry.npmjs.org/google-protobuf/-/google-protobuf-3.21.2.tgz", - "integrity": "sha512-3MSOYFO5U9mPGikIYCzK0SaThypfGgS6bHqrUGXG3DPHCrb+txNqeEcns1W0lkGfk0rCyNXm7xB9rMxnCiZOoA==" + "integrity": "sha512-3MSOYFO5U9mPGikIYCzK0SaThypfGgS6bHqrUGXG3DPHCrb+txNqeEcns1W0lkGfk0rCyNXm7xB9rMxnCiZOoA==", + "dev": true }, "node_modules/got": { "version": "9.6.0", diff --git a/package.json b/package.json index 383b101e3..63aaab0a1 100644 --- a/package.json +++ b/package.json @@ -46,7 +46,6 @@ }, "homepage": "https://github.com/viamrobotics/viam-typescript-sdk#readme", "dependencies": { - "@viamrobotics/rpc": "^0.2.6", "exponential-backoff": "^3.1.1" }, "devDependencies": { diff --git a/src/app/viam-transport.ts b/src/app/viam-transport.ts index d537efff2..4f7028425 100644 --- a/src/app/viam-transport.ts +++ b/src/app/viam-transport.ts @@ -1,5 +1,5 @@ import { grpc } from '@improbable-eng/grpc-web'; -import { dialDirect } from '@viamrobotics/rpc'; +import { dialDirect } from '../rpc'; import { AuthenticateRequest, diff --git a/src/robot/client.ts b/src/robot/client.ts index 81c155fe6..3c36b38ab 100644 --- a/src/robot/client.ts +++ b/src/robot/client.ts @@ -1,6 +1,5 @@ /* eslint-disable max-classes-per-file */ import { grpc } from '@improbable-eng/grpc-web'; -import { dialDirect, dialWebRTC, type DialOptions } from '@viamrobotics/rpc'; import { backOff } from 'exponential-backoff'; import { Duration } from 'google-protobuf/google/protobuf/duration_pb'; import { isCredential, type Credentials } from '../app/viam-transport'; @@ -32,6 +31,7 @@ import { SensorsServiceClient } from '../gen/service/sensors/v1/sensors_pb_servi import { SLAMServiceClient } from '../gen/service/slam/v1/slam_pb_service'; import { VisionServiceClient } from '../gen/service/vision/v1/vision_pb_service'; import { ViamResponseStream } from '../responses'; +import { dialDirect, dialWebRTC, type DialOptions } from '../rpc'; import { MetadataTransport, encodeResourceName, promisify } from '../utils'; import GRPCConnectionManager from './grpc-connection-manager'; import type { Robot, RobotStatusStream } from './robot'; diff --git a/src/robot/session-manager.test.ts b/src/robot/session-manager.test.ts index ece67396b..91dd5b629 100644 --- a/src/robot/session-manager.test.ts +++ b/src/robot/session-manager.test.ts @@ -1,10 +1,10 @@ // @vitest-environment happy-dom -import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { ConnectionClosedError } from '@viamrobotics/rpc'; -import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport'; import { grpc } from '@improbable-eng/grpc-web'; +import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; import { RobotServiceClient } from '../gen/robot/v1/robot_pb_service'; +import { ConnectionClosedError } from '../rpc'; vi.mock('../gen/robot/v1/robot_pb_service'); import SessionManager from './session-manager'; diff --git a/src/robot/session-manager.ts b/src/robot/session-manager.ts index ac4274d1a..25bcfa70d 100644 --- a/src/robot/session-manager.ts +++ b/src/robot/session-manager.ts @@ -1,10 +1,10 @@ -import { ConnectionClosedError } from '@viamrobotics/rpc'; import { grpc } from '@improbable-eng/grpc-web'; +import robotApi from '../gen/robot/v1/robot_pb'; import { RobotServiceClient, type ServiceError, } from '../gen/robot/v1/robot_pb_service'; -import robotApi from '../gen/robot/v1/robot_pb'; +import { ConnectionClosedError } from '../rpc'; import SessionTransport from './session-transport'; const timeoutBlob = new Blob( diff --git a/src/robot/session-transport.ts b/src/robot/session-transport.ts index f93581784..abbbefa73 100644 --- a/src/robot/session-transport.ts +++ b/src/robot/session-transport.ts @@ -1,5 +1,5 @@ -import { GRPCError } from '@viamrobotics/rpc'; import { grpc } from '@improbable-eng/grpc-web'; +import { GRPCError } from '../rpc'; import type SessionManager from './session-manager'; export default class SessionTransport implements grpc.Transport { diff --git a/src/rpc/base-channel.ts b/src/rpc/base-channel.ts new file mode 100644 index 000000000..d662b5e67 --- /dev/null +++ b/src/rpc/base-channel.ts @@ -0,0 +1,81 @@ +import type { ProtobufMessage } from '@improbable-eng/grpc-web/dist/typings/message'; +import { ConnectionClosedError } from './connection-closed-error'; + +export class BaseChannel { + public readonly ready: Promise; + + private readonly peerConn: RTCPeerConnection; + private readonly dataChannel: RTCDataChannel; + private pResolve: ((value: unknown) => void) | undefined; + private pReject: ((reason?: unknown) => void) | undefined; + + private closed = false; + private closedReason: Error | undefined; + + protected maxDataChannelSize = 65_535; + + constructor(peerConn: RTCPeerConnection, dataChannel: RTCDataChannel) { + this.peerConn = peerConn; + this.dataChannel = dataChannel; + + this.ready = new Promise((resolve, reject) => { + this.pResolve = resolve; + this.pReject = reject; + }); + + dataChannel.addEventListener('open', () => this.onChannelOpen()); + dataChannel.addEventListener('close', () => this.onChannelClose()); + dataChannel.addEventListener('error', (ev) => { + this.onChannelError(ev); + }); + + peerConn.addEventListener('iceconnectionstatechange', () => { + const state = peerConn.iceConnectionState; + if ( + !(state === 'failed' || state === 'disconnected' || state === 'closed') + ) { + return; + } + this.pReject?.(new Error(`ICE connection failed with state: ${state}`)); + }); + } + + public close() { + this.closeWithReason(undefined); + } + + public isClosed() { + return this.closed; + } + + public isClosedReason() { + return this.closedReason; + } + + protected closeWithReason(err?: Error) { + if (this.closed) { + return; + } + this.closed = true; + this.closedReason = err; + this.pReject?.(err); + this.peerConn.close(); + } + + private onChannelOpen() { + this.pResolve?.(undefined); + } + + private onChannelClose() { + this.closeWithReason(new ConnectionClosedError('data channel closed')); + } + + private onChannelError(ev: any) { + console.error('channel error', ev); + this.closeWithReason(new Error(ev)); + } + + protected write(msg: ProtobufMessage) { + this.dataChannel.send(msg.serializeBinary()); + } +} diff --git a/src/rpc/base-stream.ts b/src/rpc/base-stream.ts new file mode 100644 index 000000000..374a1e0be --- /dev/null +++ b/src/rpc/base-stream.ts @@ -0,0 +1,62 @@ +import type { grpc } from '@improbable-eng/grpc-web'; +import type { PacketMessage, Stream } from '../gen/proto/rpc/webrtc/v1/grpc_pb'; + +// MaxMessageSize (2^25) is the maximum size a gRPC message can be. +const MaxMessageSize = 33_554_432; + +export class BaseStream { + protected readonly stream: Stream; + private readonly onDone: (id: number) => void; + protected readonly opts: grpc.TransportOptions; + protected closed = false; + private readonly packetBuf: Uint8Array[] = []; + private packetBufSize = 0; + private err: Error | undefined; + + constructor( + stream: Stream, + onDone: (id: number) => void, + opts: grpc.TransportOptions + ) { + this.stream = stream; + this.onDone = onDone; + this.opts = opts; + } + + public closeWithRecvError(err?: Error) { + if (this.closed) { + return; + } + this.closed = true; + this.err = err; + this.onDone(this.stream.getId()); + // pretty sure passing the error does nothing. + this.opts.onEnd(this.err); + } + + protected processPacketMessage(msg: PacketMessage): Uint8Array | undefined { + const data = msg.getData_asU8(); + if (data.length + this.packetBufSize > MaxMessageSize) { + this.packetBuf.length = 0; + this.packetBufSize = 0; + console.error( + `message size larger than max ${MaxMessageSize}; discarding` + ); + return undefined; + } + this.packetBuf.push(data); + this.packetBufSize += data.length; + if (msg.getEom()) { + const pktData = new Uint8Array(this.packetBufSize); + let position = 0; + for (const partialData of this.packetBuf) { + pktData.set(partialData, position); + position += partialData.length; + } + this.packetBuf.length = 0; + this.packetBufSize = 0; + return pktData; + } + return undefined; + } +} diff --git a/src/rpc/client-channel.ts b/src/rpc/client-channel.ts new file mode 100644 index 000000000..f5917d640 --- /dev/null +++ b/src/rpc/client-channel.ts @@ -0,0 +1,170 @@ +import type { grpc } from '@improbable-eng/grpc-web'; +import { + Request, + RequestHeaders, + RequestMessage, + Response, + Stream, +} from '../gen/proto/rpc/webrtc/v1/grpc_pb'; +import { BaseChannel } from './base-channel'; +import { ClientStream } from './client-stream'; +import { ConnectionClosedError } from './connection-closed-error'; + +// MaxStreamCount is the max number of streams a channel can have. +const MaxStreamCount = 256; + +interface activeClienStream { + cs: ClientStream; +} + +export class ClientChannel extends BaseChannel { + private streamIDCounter = 0; + private readonly streams = new Map(); + + constructor(pc: RTCPeerConnection, dc: RTCDataChannel) { + super(pc, dc); + dc.addEventListener('message', (event: MessageEvent<'message'>) => { + this.onChannelMessage(event); + }); + pc.addEventListener('iceconnectionstatechange', () => { + const state = pc.iceConnectionState; + if ( + !(state === 'failed' || state === 'disconnected' || state === 'closed') + ) { + return; + } + this.onConnectionTerminated(); + }); + dc.addEventListener('close', () => this.onConnectionTerminated()); + } + + public transportFactory(): grpc.TransportFactory { + return (opts: grpc.TransportOptions) => { + return this.newStream(this.nextStreamID(), opts); + }; + } + + private onConnectionTerminated() { + // we may call this twice but we know closed will be true at this point. + this.closeWithReason(new ConnectionClosedError('data channel closed')); + const err = new ConnectionClosedError('connection terminated'); + for (const stream of this.streams.values()) { + stream.cs.closeWithRecvError(err); + } + } + + private onChannelMessage(event: MessageEvent) { + let resp: Response; + try { + resp = Response.deserializeBinary( + new Uint8Array(event.data as ArrayBuffer) + ); + } catch (error) { + console.error('error deserializing message', error); + return; + } + + const stream = resp.getStream(); + if (stream === undefined) { + console.error('no stream id; discarding'); + return; + } + + const id = stream.getId(); + const activeStream = this.streams.get(id); + if (activeStream === undefined) { + console.error('no stream for id; discarding', 'id', id); + return; + } + activeStream.cs.onResponse(resp); + } + + private nextStreamID(): Stream { + const stream = new Stream(); + const thisStreamId = this.streamIDCounter; + this.streamIDCounter += 1; + stream.setId(thisStreamId); + return stream; + } + + private newStream( + stream: Stream, + opts: grpc.TransportOptions + ): grpc.Transport { + if (this.isClosed()) { + return new FailingClientStream( + new ConnectionClosedError('connection closed'), + opts + ); + } + let activeStream = this.streams.get(stream.getId()); + if (activeStream === undefined) { + if (Object.keys(this.streams).length > MaxStreamCount) { + return new FailingClientStream(new Error('stream limit hit'), opts); + } + const clientStream = new ClientStream( + this, + stream, + (id: number) => this.removeStreamByID(id), + opts + ); + activeStream = { cs: clientStream }; + this.streams.set(stream.getId(), activeStream); + } + return activeStream.cs; + } + + private removeStreamByID(id: number) { + this.streams.delete(id); + } + + public writeHeaders(stream: Stream, headers: RequestHeaders) { + const request = new Request(); + request.setStream(stream); + request.setHeaders(headers); + this.write(request); + } + + public writeMessage(stream: Stream, msg: RequestMessage) { + const request = new Request(); + request.setStream(stream); + request.setMessage(msg); + this.write(request); + } + + public writeReset(stream: Stream) { + const request = new Request(); + request.setStream(stream); + request.setRstStream(true); + this.write(request); + } +} + +class FailingClientStream implements grpc.Transport { + private readonly err: Error; + private readonly opts: grpc.TransportOptions; + + constructor(err: Error, opts: grpc.TransportOptions) { + this.err = err; + this.opts = opts; + } + + public start() { + setTimeout(() => this.opts.onEnd(this.err)); + } + + // eslint-disable-next-line class-methods-use-this + public sendMessage() { + // do nothing. + } + + // eslint-disable-next-line class-methods-use-this + public finishSend() { + // do nothing. + } + + // eslint-disable-next-line class-methods-use-this + public cancel() { + // do nothing. + } +} diff --git a/src/rpc/client-stream.ts b/src/rpc/client-stream.ts new file mode 100644 index 000000000..ba0dd1651 --- /dev/null +++ b/src/rpc/client-stream.ts @@ -0,0 +1,280 @@ +import { grpc } from '@improbable-eng/grpc-web'; +import { + Metadata, + PacketMessage, + RequestHeaders, + RequestMessage, + Response, + ResponseHeaders, + ResponseMessage, + ResponseTrailers, + Stream, + Strings, +} from '../gen/proto/rpc/webrtc/v1/grpc_pb'; +import { BaseStream } from './base-stream'; +import type { ClientChannel } from './client-channel'; +import { GRPCError } from './grpc-error'; + +// see golang/client_stream.go +const maxRequestMessagePacketDataSize = 16_373; + +export class ClientStream extends BaseStream implements grpc.Transport { + private readonly channel: ClientChannel; + private headersReceived = false; + private trailersReceived = false; + + constructor( + channel: ClientChannel, + stream: Stream, + onDone: (id: number) => void, + opts: grpc.TransportOptions + ) { + super(stream, onDone, opts); + this.channel = channel; + } + + public start(metadata: grpc.Metadata) { + const method = `/${this.opts.methodDefinition.service.serviceName}/${this.opts.methodDefinition.methodName}`; + const requestHeaders = new RequestHeaders(); + requestHeaders.setMethod(method); + requestHeaders.setMetadata(fromGRPCMetadata(metadata)); + + try { + this.channel.writeHeaders(this.stream, requestHeaders); + } catch (error) { + console.error('error writing headers', error); + this.closeWithRecvError(error as Error); + } + } + + public sendMessage(msgBytes?: Uint8Array) { + // skip frame header bytes + if (msgBytes) { + this.writeMessage(false, msgBytes.slice(5)); + return; + } + this.writeMessage(false, undefined); + } + + public resetStream() { + try { + this.channel.writeReset(this.stream); + } catch (error) { + console.error('error writing reset', error); + this.closeWithRecvError(error as Error); + } + } + + public finishSend() { + if (!this.opts.methodDefinition.requestStream) { + return; + } + this.writeMessage(true, undefined); + } + + public cancel() { + if (this.closed) { + return; + } + this.resetStream(); + } + + private writeMessage(eos: boolean, msgBytes?: Uint8Array) { + try { + let remMsgBytes = msgBytes; + if (!remMsgBytes || remMsgBytes.length === 0) { + const packet = new PacketMessage(); + packet.setEom(true); + const requestMessage = new RequestMessage(); + requestMessage.setHasMessage(Boolean(remMsgBytes)); + requestMessage.setPacketMessage(packet); + requestMessage.setEos(eos); + this.channel.writeMessage(this.stream, requestMessage); + return; + } + + while (remMsgBytes.length > 0) { + const amountToSend = Math.min( + remMsgBytes.length, + maxRequestMessagePacketDataSize + ); + const packet = new PacketMessage(); + packet.setData(remMsgBytes.slice(0, amountToSend)); + remMsgBytes = remMsgBytes.slice(amountToSend); + if (remMsgBytes.length === 0) { + packet.setEom(true); + } + const requestMessage = new RequestMessage(); + requestMessage.setHasMessage(Boolean(remMsgBytes)); + requestMessage.setPacketMessage(packet); + requestMessage.setEos(eos); + this.channel.writeMessage(this.stream, requestMessage); + } + } catch (error) { + console.error('error writing message', error); + this.closeWithRecvError(error as Error); + } + } + + public onResponse(resp: Response) { + switch (resp.getTypeCase()) { + case Response.TypeCase.HEADERS: { + if (this.headersReceived) { + this.closeWithRecvError(new Error('headers already received')); + return; + } + if (this.trailersReceived) { + this.closeWithRecvError(new Error('headers received after trailers')); + return; + } + const respHeaders = resp.getHeaders(); + if (respHeaders === undefined) { + this.closeWithRecvError(new Error('no headers in response')); + return; + } + this.processHeaders(respHeaders); + break; + } + case Response.TypeCase.MESSAGE: { + if (!this.headersReceived) { + this.closeWithRecvError(new Error('headers not yet received')); + return; + } + if (this.trailersReceived) { + this.closeWithRecvError(new Error('headers received after trailers')); + return; + } + const respMessage = resp.getMessage(); + if (respMessage === undefined) { + this.closeWithRecvError(new Error('no message in response')); + return; + } + this.processMessage(respMessage); + break; + } + case Response.TypeCase.TRAILERS: { + const respTrailers = resp.getTrailers(); + if (respTrailers === undefined) { + this.closeWithRecvError(new Error('no trailers in response')); + return; + } + this.processTrailers(respTrailers); + break; + } + default: { + console.error('unknown response type', resp.getTypeCase()); + break; + } + } + } + + private processHeaders(headers: ResponseHeaders) { + this.headersReceived = true; + this.opts.onHeaders(toGRPCMetadata(headers.getMetadata()), 200); + } + + private processMessage(msg: ResponseMessage) { + const pktMsg = msg.getPacketMessage(); + if (!pktMsg) { + return; + } + const result = super.processPacketMessage(pktMsg); + if (!result) { + return; + } + const chunk = new ArrayBuffer(result.length + 5); + new DataView(chunk, 1, 4).setUint32(0, result.length, false); + new Uint8Array(chunk, 5).set(result); + this.opts.onChunk(new Uint8Array(chunk)); + } + + private processTrailers(trailers: ResponseTrailers) { + this.trailersReceived = true; + const headers = toGRPCMetadata(trailers.getMetadata()); + let statusCode; + let statusMessage; + const status = trailers.getStatus(); + if (status) { + statusCode = status.getCode(); + statusMessage = status.getMessage(); + headers.set('grpc-status', `${status.getCode()}`); + headers.set('grpc-message', status.getMessage()); + } else { + statusCode = 0; + headers.set('grpc-status', '0'); + statusMessage = ''; + } + + const headerBytes = headersToBytes(headers); + const chunk = new ArrayBuffer(headerBytes.length + 5); + new DataView(chunk, 0, 1).setUint8(0, 128); + new DataView(chunk, 1, 4).setUint32(0, headerBytes.length, false); + new Uint8Array(chunk, 5).set(headerBytes); + this.opts.onChunk(new Uint8Array(chunk)); + if (statusCode === 0) { + this.closeWithRecvError(); + return; + } + this.closeWithRecvError(new GRPCError(statusCode, statusMessage)); + } +} + +// from https://github.com/improbable-eng/grpc-web/blob/6fb683f067bd56862c3a510bc5590b955ce46d2a/ts/src/ChunkParser.ts#L22 +export const encodeASCII = (input: string): Uint8Array => { + const encoded = new Uint8Array(input.length); + // eslint-disable-next-line no-plusplus + for (let i = 0; i !== input.length; i++) { + // eslint-disable-next-line unicorn/prefer-code-point + const charCode = input.charCodeAt(i); + if (!isValidHeaderAscii(charCode)) { + throw new Error('Metadata contains invalid ASCII'); + } + encoded[i] = charCode; + } + return encoded; +}; + +const isAllowedControlChars = (char: number) => + char === 0x9 || char === 0xa || char === 0xd; + +const isValidHeaderAscii = (val: number): boolean => { + return isAllowedControlChars(val) || (val >= 0x20 && val <= 0x7e); +}; + +const headersToBytes = (headers: grpc.Metadata): Uint8Array => { + let asString = ''; + // eslint-disable-next-line unicorn/no-array-for-each + headers.forEach((key: string, values: string[]) => { + asString += `${key}: ${values.join(', ')}\r\n`; + }); + return encodeASCII(asString); +}; + +// from https://github.com/jsmouret/grpc-over-webrtc/blob/45cd6d6cf516e78b1e262ea7aa741bc7a7a93dbc/client-improbable/src/grtc/webrtcclient.ts#L7 +const fromGRPCMetadata = (metadata?: grpc.Metadata): Metadata | undefined => { + if (!metadata) { + return undefined; + } + const result = new Metadata(); + const md = result.getMdMap(); + // eslint-disable-next-line unicorn/no-array-for-each + metadata.forEach((key: string, values: string[]) => { + const strings = new Strings(); + strings.setValuesList(values); + md.set(key, strings); + }); + if (result.getMdMap().getLength() === 0) { + return undefined; + } + return result; +}; + +const toGRPCMetadata = (metadata?: Metadata): grpc.Metadata => { + const result = new grpc.Metadata(); + if (metadata) { + for (const [key, entry] of metadata.getMdMap().entries()) { + result.append(key, entry.getValuesList()); + } + } + return result; +}; diff --git a/src/rpc/connection-closed-error.ts b/src/rpc/connection-closed-error.ts new file mode 100644 index 000000000..b0c866b81 --- /dev/null +++ b/src/rpc/connection-closed-error.ts @@ -0,0 +1,21 @@ +export class ConnectionClosedError extends Error { + public override readonly name = 'ConnectionClosedError'; + + constructor(msg: string) { + super(msg); + Object.setPrototypeOf(this, ConnectionClosedError.prototype); + } + + static isError(error: unknown): boolean { + if (error instanceof ConnectionClosedError) { + return true; + } + if (typeof error === 'string') { + return error === 'Response closed without headers'; + } + if (error instanceof Error) { + return error.message === 'Response closed without headers'; + } + return false; + } +} diff --git a/src/rpc/dial.ts b/src/rpc/dial.ts new file mode 100644 index 000000000..81e46a051 --- /dev/null +++ b/src/rpc/dial.ts @@ -0,0 +1,774 @@ +import { grpc } from '@improbable-eng/grpc-web'; +import type { ProtobufMessage } from '@improbable-eng/grpc-web/dist/typings/message'; +import { Code } from '../gen/google/rpc/code_pb'; +import { Status } from '../gen/google/rpc/status_pb'; +import { + AuthenticateRequest, + AuthenticateResponse, + AuthenticateToRequest, + AuthenticateToResponse, + Credentials as PBCredentials, +} from '../gen/proto/rpc/v1/auth_pb'; +import { + AuthService, + ExternalAuthService, +} from '../gen/proto/rpc/v1/auth_pb_service'; +import { + CallRequest, + CallResponse, + CallUpdateRequest, + CallUpdateResponse, + ICECandidate, + OptionalWebRTCConfigRequest, + OptionalWebRTCConfigResponse, + WebRTCConfig, +} from '../gen/proto/rpc/webrtc/v1/signaling_pb'; +import { SignalingService } from '../gen/proto/rpc/webrtc/v1/signaling_pb_service'; +import { ClientChannel } from './client-channel'; +import { ConnectionClosedError } from './connection-closed-error'; +import { addSdpFields, newPeerConnectionForClient } from './peer'; + +import type { CrossBrowserHttpTransportInit } from '@improbable-eng/grpc-web/dist/typings/transports/http/http'; +import { atob, btoa } from './polyfills'; + +export interface DialOptions { + authEntity?: string | undefined; + credentials?: Credentials | undefined; + webrtcOptions?: DialWebRTCOptions; + externalAuthAddress?: string | undefined; + externalAuthToEntity?: string | undefined; + + // `accessToken` allows a pre-authenticated client to dial with + // an authorization header. Direct dial will have the access token + // appended to the "Authorization: Bearer" header. WebRTC dial will + // appened it to the signaling server communication + // + // If enabled, other auth options have no affect. Eg. authEntity, credentials, + // externalAuthAddress, externalAuthToEntity, webrtcOptions.signalingAccessToken + accessToken?: string | undefined; + + // set timeout in milliseconds for dialing. + dialTimeout?: number | undefined; +} + +export interface DialWebRTCOptions { + disableTrickleICE: boolean; + rtcConfig?: RTCConfiguration; + + // signalingAuthEntity is the entity to authenticate as to the signaler. + signalingAuthEntity?: string; + + // signalingExternalAuthAddress is the address to perform external auth yet. + // This is unlikely to be needed since the signaler is typically in the same + // place where authentication happens. + signalingExternalAuthAddress?: string; + + // signalingExternalAuthToEntity is the entity to authenticate for after + // externally authenticating. + // This is unlikely to be needed since the signaler is typically in the same + // place where authentication happens. + signalingExternalAuthToEntity?: string; + + // signalingCredentials are used to authenticate the request to the signaling server. + signalingCredentials?: Credentials; + + // `signalingAccessToken` allows a pre-authenticated client to dial with + // an authorization header to the signaling server. This skips the Authenticate() + // request to the singaling server or external auth but does not skip the + // AuthenticateTo() request to retrieve the credentials at the external auth + // endpoint. + // + // If enabled, other auth options have no affect. Eg. authEntity, credentials, signalingAuthEntity, signalingCredentials. + signalingAccessToken?: string; + + // `additionalSDPValues` is a collection of additional SDP values that we want to pass into the connection's call request. + additionalSdpFields?: Record; +} + +export interface Credentials { + type: string; + payload: string; +} + +export const dialDirect = async ( + address: string, + opts?: DialOptions +): Promise => { + validateDialOptions(opts); + const defaultFactory = ( + transportOpts: grpc.TransportOptions + ): grpc.Transport => { + const transFact: ( + init: CrossBrowserHttpTransportInit + ) => grpc.TransportFactory = + window.VIAM?.GRPC_TRANSPORT_FACTORY ?? grpc.CrossBrowserHttpTransport; + return transFact({ withCredentials: false })(transportOpts); + }; + + // Client already has access token with no external auth, skip Authenticate process. + if ( + opts?.accessToken && + !(opts.externalAuthAddress && opts.externalAuthToEntity) + ) { + const md = new grpc.Metadata(); + md.set('authorization', `Bearer ${opts.accessToken}`); + return (transportOpts: grpc.TransportOptions): grpc.Transport => { + return new AuthenticatedTransport(transportOpts, defaultFactory, md); + }; + } + + if (!opts || (!opts.credentials && !opts.accessToken)) { + return defaultFactory; + } + + return makeAuthenticatedTransportFactory(address, defaultFactory, opts); +}; + +const makeAuthenticatedTransportFactory = async ( + address: string, + defaultFactory: grpc.TransportFactory, + opts: DialOptions +): Promise => { + let accessToken = ''; + // eslint-disable-next-line sonarjs/cognitive-complexity + const getExtraMetadata = async (): Promise => { + const md = new grpc.Metadata(); + // TODO(GOUT-10): handle expiration + if (accessToken === '') { + let thisAccessToken = ''; + + let pResolve: (value: grpc.Metadata) => void; + let pReject: (reason?: unknown) => void; + + if (!opts.accessToken || opts.accessToken === '') { + const request = new AuthenticateRequest(); + request.setEntity(opts.authEntity ?? address.replace(/^.*:\/\//u, '')); + if (opts.credentials) { + const creds = new PBCredentials(); + creds.setType(opts.credentials.type); + creds.setPayload(opts.credentials.payload); + request.setCredentials(creds); + } + + const done = new Promise((resolve, reject) => { + pResolve = resolve; + pReject = reject; + }); + + grpc.invoke(AuthService.Authenticate, { + request, + host: opts.externalAuthAddress ?? address, + transport: defaultFactory, + onMessage: (message: AuthenticateResponse) => { + thisAccessToken = message.getAccessToken(); + }, + onEnd: ( + code: grpc.Code, + msg: string | undefined, + _trailers: grpc.Metadata + ) => { + if (code === grpc.Code.OK) { + pResolve(md); + } else { + pReject(msg); + } + }, + }); + await done; + } else { + thisAccessToken = opts.accessToken; + } + + // eslint-disable-next-line require-atomic-updates + accessToken = thisAccessToken; + + if (opts.externalAuthAddress && opts.externalAuthToEntity) { + const authMd = new grpc.Metadata(); + authMd.set('authorization', `Bearer ${accessToken}`); + + const done = new Promise((resolve, reject) => { + pResolve = resolve; + pReject = reject; + }); + thisAccessToken = ''; + + const request = new AuthenticateToRequest(); + request.setEntity(opts.externalAuthToEntity); + grpc.invoke(ExternalAuthService.AuthenticateTo, { + request, + host: opts.externalAuthAddress, + transport: defaultFactory, + metadata: authMd, + onMessage: (message: AuthenticateToResponse) => { + thisAccessToken = message.getAccessToken(); + }, + onEnd: ( + code: grpc.Code, + msg: string | undefined, + _trailers: grpc.Metadata + ) => { + if (code === grpc.Code.OK) { + pResolve(authMd); + } else { + pReject(msg); + } + }, + }); + await done; + // eslint-disable-next-line require-atomic-updates + accessToken = thisAccessToken; + } + } + md.set('authorization', `Bearer ${accessToken}`); + return md; + }; + const extraMd = await getExtraMetadata(); + return (transportOpts: grpc.TransportOptions): grpc.Transport => { + return new AuthenticatedTransport(transportOpts, defaultFactory, extraMd); + }; +}; + +class AuthenticatedTransport implements grpc.Transport { + protected readonly opts: grpc.TransportOptions; + protected readonly transport: grpc.Transport; + protected readonly extraMetadata: grpc.Metadata; + + constructor( + opts: grpc.TransportOptions, + defaultFactory: grpc.TransportFactory, + extraMetadata: grpc.Metadata + ) { + this.opts = opts; + this.extraMetadata = extraMetadata; + this.transport = defaultFactory(opts); + } + + public start(metadata: grpc.Metadata) { + // eslint-disable-next-line unicorn/no-array-for-each + this.extraMetadata.forEach((key: string, values: string | string[]) => { + metadata.set(key, values); + }); + this.transport.start(metadata); + } + + public sendMessage(msgBytes: Uint8Array) { + this.transport.sendMessage(msgBytes); + } + + public finishSend() { + this.transport.finishSend(); + } + + public cancel() { + this.transport.cancel(); + } +} + +export interface WebRTCConnection { + transportFactory: grpc.TransportFactory; + peerConnection: RTCPeerConnection; + dataChannel: RTCDataChannel; +} + +const getOptionalWebRTCConfig = async ( + signalingAddress: string, + host: string, + opts?: DialOptions +): Promise => { + const optsCopy = { ...opts } as DialOptions; + const directTransport = await dialDirect(signalingAddress, optsCopy); + + let pResolve: (value: WebRTCConfig) => void; + let pReject: (reason?: unknown) => void; + + let result: WebRTCConfig | undefined; + const done = new Promise((resolve, reject) => { + pResolve = resolve; + pReject = reject; + }); + + grpc.unary(SignalingService.OptionalWebRTCConfig, { + request: new OptionalWebRTCConfigRequest(), + metadata: { + 'rpc-host': host, + }, + host: signalingAddress, + transport: directTransport, + onEnd: (resp: grpc.UnaryOutput) => { + const { status, statusMessage, message } = resp; + if (status === grpc.Code.OK && message) { + result = message.getConfig(); + if (!result) { + pResolve(new WebRTCConfig()); + return; + } + pResolve(result); + // In some cases the `OptionalWebRTCConfig` method seems to be unimplemented, even + // when building `viam-server` from latest. Falling back to a default config seems + // harmless in these cases, and allows connection to continue. + } else if (status === grpc.Code.Unimplemented) { + pResolve(new WebRTCConfig()); + } else { + pReject(statusMessage); + } + }, + }); + + await done; + + if (!result) { + throw new Error('no config'); + } + return result; +}; + +// dialWebRTC makes a connection to given host by signaling with the address provided. A Promise is returned +// upon successful connection that contains a transport factory to use with gRPC client as well as the WebRTC +// PeerConnection itself. Care should be taken with the PeerConnection and is currently returned for experimental +// use. +// TODO(GOUT-7): figure out decent way to handle reconnect on connection termination +export const dialWebRTC = async ( + signalingAddress: string, + host: string, + opts?: DialOptions + // eslint-disable-next-line sonarjs/cognitive-complexity +): Promise => { + const usableSignalingAddress = signalingAddress.replace(/\/$/u, ''); + validateDialOptions(opts); + + // TODO(RSDK-2836): In general, this logic should be in parity with the golang implementation. + // https://github.com/viamrobotics/goutils/blob/main/rpc/wrtc_client.go#L160-L175 + const config = await getOptionalWebRTCConfig( + usableSignalingAddress, + host, + opts + ); + const additionalIceServers: RTCIceServer[] = config + .toObject() + .additionalIceServersList.map((ice) => { + return { + urls: ice.urlsList, + credential: ice.credential, + username: ice.username, + }; + }); + + const usableOpts = opts ?? {}; + + let webrtcOpts: DialWebRTCOptions; + if (usableOpts.webrtcOptions) { + // RSDK-8715: We deep copy here to avoid mutating the input config's `rtcConfig.iceServers` + // list. + webrtcOpts = JSON.parse( + JSON.stringify(usableOpts.webrtcOptions) + ) as DialWebRTCOptions; + if (webrtcOpts.rtcConfig) { + webrtcOpts.rtcConfig.iceServers = [ + ...(webrtcOpts.rtcConfig.iceServers ?? []), + ...additionalIceServers, + ]; + } else { + webrtcOpts.rtcConfig = { iceServers: additionalIceServers }; + } + } else { + // use additional webrtc config as default + webrtcOpts = { + disableTrickleICE: config.getDisableTrickle(), + rtcConfig: { + iceServers: additionalIceServers, + }, + }; + } + + const { pc, dc } = await newPeerConnectionForClient( + webrtcOpts.disableTrickleICE, + webrtcOpts.rtcConfig, + webrtcOpts.additionalSdpFields + ); + let successful = false; + + try { + // replace auth entity and creds + let optsCopy = usableOpts; + optsCopy = { ...usableOpts } as DialOptions; + + if (!usableOpts.accessToken) { + optsCopy.authEntity = usableOpts.webrtcOptions?.signalingAuthEntity; + if (!optsCopy.authEntity) { + optsCopy.authEntity = optsCopy.externalAuthAddress + ? usableOpts.externalAuthAddress?.replace(/^.*:\/\//u, '') + : usableSignalingAddress.replace(/^.*:\/\//u, ''); + } + optsCopy.credentials = usableOpts.webrtcOptions?.signalingCredentials; + optsCopy.accessToken = usableOpts.webrtcOptions?.signalingAccessToken; + } + + optsCopy.externalAuthAddress = + usableOpts.webrtcOptions?.signalingExternalAuthAddress; + optsCopy.externalAuthToEntity = + usableOpts.webrtcOptions?.signalingExternalAuthToEntity; + + const directTransport = await dialDirect(usableSignalingAddress, optsCopy); + const client = grpc.client(SignalingService.Call, { + host: usableSignalingAddress, + transport: directTransport, + }); + + let uuid = ''; + // only send once since exchange may end or ICE may end + let sentDoneOrErrorOnce = false; + const sendError = (err: string) => { + if (sentDoneOrErrorOnce) { + return; + } + sentDoneOrErrorOnce = true; + const callRequestUpdate = new CallUpdateRequest(); + callRequestUpdate.setUuid(uuid); + const status = new Status(); + status.setCode(Code.UNKNOWN); + status.setMessage(err); + callRequestUpdate.setError(status); + grpc.unary(SignalingService.CallUpdate, { + request: callRequestUpdate, + metadata: { + 'rpc-host': host, + }, + host: usableSignalingAddress, + transport: directTransport, + onEnd: (output: grpc.UnaryOutput) => { + const { status: grpcStatus, statusMessage, message } = output; + if (grpcStatus === grpc.Code.OK && message) { + return; + } + console.error(statusMessage); + }, + }); + }; + const sendDone = () => { + if (sentDoneOrErrorOnce) { + return; + } + sentDoneOrErrorOnce = true; + const callRequestUpdate = new CallUpdateRequest(); + callRequestUpdate.setUuid(uuid); + callRequestUpdate.setDone(true); + grpc.unary(SignalingService.CallUpdate, { + request: callRequestUpdate, + metadata: { + 'rpc-host': host, + }, + host: usableSignalingAddress, + transport: directTransport, + onEnd: (output: grpc.UnaryOutput) => { + const { status, statusMessage, message } = output; + if (status === grpc.Code.OK && message) { + return; + } + console.error(statusMessage); + }, + }); + }; + + let pResolve: (value: unknown) => void; + const remoteDescSet = new Promise((resolve) => { + pResolve = resolve; + }); + let exchangeDone = false; + if (!webrtcOpts.disableTrickleICE) { + // set up offer + const offerDesc = await pc.createOffer({}); + + let iceComplete = false; + let numCallUpdates = 0; + let maxCallUpdateDuration = 0; + let totalCallUpdateDuration = 0; + + pc.addEventListener('iceconnectionstatechange', () => { + if (pc.iceConnectionState !== 'completed' || numCallUpdates === 0) { + return; + } + const averageCallUpdateDuration = + totalCallUpdateDuration / numCallUpdates; + console.groupCollapsed('Caller update statistics'); + console.table({ + num_updates: numCallUpdates, + average_duration: `${averageCallUpdateDuration}ms`, + max_duration: `${maxCallUpdateDuration}ms`, + }); + console.groupEnd(); + }); + pc.addEventListener( + 'icecandidate', + async (event: { candidate: RTCIceCandidateInit | null }) => { + await remoteDescSet; + if (exchangeDone) { + return; + } + + if (event.candidate === null) { + iceComplete = true; + sendDone(); + return; + } + + if (event.candidate.candidate !== undefined) { + console.debug(`gathered local ICE ${event.candidate.candidate}`); + } + const iProto = iceCandidateToProto(event.candidate); + const callRequestUpdate = new CallUpdateRequest(); + callRequestUpdate.setUuid(uuid); + callRequestUpdate.setCandidate(iProto); + const callUpdateStart = new Date(); + grpc.unary(SignalingService.CallUpdate, { + request: callRequestUpdate, + metadata: { + 'rpc-host': host, + }, + host: usableSignalingAddress, + transport: directTransport, + onEnd: (output: grpc.UnaryOutput) => { + const { status, statusMessage, message } = output; + if (status === grpc.Code.OK && message) { + numCallUpdates += 1; + const callUpdateEnd = new Date(); + const callUpdateDuration = + callUpdateEnd.getTime() - callUpdateStart.getTime(); + if (callUpdateDuration > maxCallUpdateDuration) { + maxCallUpdateDuration = callUpdateDuration; + } + totalCallUpdateDuration += callUpdateDuration; + return; + } + if (exchangeDone || iceComplete) { + return; + } + console.error('error sending candidate', statusMessage); + }, + }); + } + ); + + await pc.setLocalDescription(offerDesc); + } + + // initialize cc here so we can use it in the callbacks + const cc = new ClientChannel(pc, dc); + + // set timeout for dial attempt if a timeout is specified + if (usableOpts.dialTimeout) { + setTimeout(() => { + if (!successful) { + cc.close(); + } + }, usableOpts.dialTimeout); + } + + let haveInit = false; + // TS says that CallResponse isn't a valid type here. More investigation required. + client.onMessage(async (message: ProtobufMessage) => { + const response = message as CallResponse; + + if (response.hasInit()) { + if (haveInit) { + sendError('got init stage more than once'); + return; + } + const init = response.getInit(); + if (init === undefined) { + sendError('no init in response'); + return; + } + haveInit = true; + uuid = response.getUuid(); + + const remoteSDP = new RTCSessionDescription( + JSON.parse(atob(init.getSdp())) + ); + if (cc.isClosed()) { + sendError('client channel is closed'); + return; + } + await pc.setRemoteDescription(remoteSDP); + + pResolve(true); + + if (webrtcOpts.disableTrickleICE) { + exchangeDone = true; + sendDone(); + } + } else if (response.hasUpdate()) { + if (!haveInit) { + sendError('got update stage before init stage'); + return; + } + if (response.getUuid() !== uuid) { + sendError(`uuid mismatch; have=${response.getUuid()} want=${uuid}`); + return; + } + const update = response.getUpdate(); + if (update === undefined) { + sendError('no update in response'); + return; + } + const cand = update.getCandidate(); + if (cand === undefined) { + return; + } + const iceCand = iceCandidateFromProto(cand); + if (iceCand.candidate !== undefined) { + console.debug(`received remote ICE ${iceCand.candidate}`); + } + try { + await pc.addIceCandidate(iceCand); + } catch (error) { + sendError(JSON.stringify(error)); + } + } else { + sendError('unknown CallResponse stage'); + } + }); + + let clientEndResolve: () => void; + let clientEndReject: (reason?: unknown) => void; + const clientEnd = new Promise((resolve, reject) => { + clientEndResolve = resolve; + clientEndReject = reject; + }); + client.onEnd( + (status: grpc.Code, statusMessage: string, _trailers: grpc.Metadata) => { + if (status === grpc.Code.OK) { + clientEndResolve(); + return; + } + if (statusMessage === 'Response closed without headers') { + clientEndReject(new ConnectionClosedError('failed to dial')); + return; + } + if (cc.isClosed()) { + clientEndReject( + new ConnectionClosedError('client channel is closed') + ); + return; + } + console.error(statusMessage); + clientEndReject(statusMessage); + } + ); + client.start({ 'rpc-host': host }); + + const callRequest = new CallRequest(); + const description = addSdpFields( + pc.localDescription, + usableOpts.webrtcOptions?.additionalSdpFields + ); + const encodedSDP = btoa(JSON.stringify(description)); + callRequest.setSdp(encodedSDP); + if (webrtcOpts.disableTrickleICE) { + callRequest.setDisableTrickle(webrtcOpts.disableTrickleICE); + } + client.send(callRequest); + + cc.ready + .then(() => clientEndResolve()) + .catch((error) => clientEndReject(error)); + await clientEnd; + await cc.ready; + exchangeDone = true; + sendDone(); + + successful = true; + return { + transportFactory: cc.transportFactory(), + peerConnection: pc, + dataChannel: dc, + }; + } finally { + if (!successful) { + pc.close(); + } + } +}; + +const iceCandidateFromProto = (i: ICECandidate): RTCIceCandidateInit => { + const candidate: RTCIceCandidateInit = { + candidate: i.getCandidate(), + }; + if (i.hasSdpMid()) { + candidate.sdpMid = i.getSdpMid(); + } + if (i.hasSdpmLineIndex()) { + candidate.sdpMLineIndex = i.getSdpmLineIndex(); + } + if (i.hasUsernameFragment()) { + candidate.usernameFragment = i.getUsernameFragment(); + } + return candidate; +}; + +const iceCandidateToProto = (i: RTCIceCandidateInit): ICECandidate => { + const candidate = new ICECandidate(); + if (i.candidate) { + candidate.setCandidate(i.candidate); + } + if (i.sdpMid) { + candidate.setSdpMid(i.sdpMid); + } + if (i.sdpMLineIndex) { + candidate.setSdpmLineIndex(i.sdpMLineIndex); + } + if (i.usernameFragment) { + candidate.setUsernameFragment(i.usernameFragment); + } + return candidate; +}; + +// eslint-disable-next-line sonarjs/cognitive-complexity +const validateDialOptions = (opts?: DialOptions) => { + if (!opts) { + return; + } + + if (opts.accessToken && opts.accessToken.length > 0) { + if (opts.authEntity) { + throw new Error('cannot set authEntity with accessToken'); + } + + if (opts.credentials) { + throw new Error('cannot set credentials with accessToken'); + } + + if (opts.webrtcOptions) { + if (opts.webrtcOptions.signalingAccessToken) { + throw new Error( + 'cannot set webrtcOptions.signalingAccessToken with accessToken' + ); + } + if (opts.webrtcOptions.signalingAuthEntity) { + throw new Error( + 'cannot set webrtcOptions.signalingAuthEntity with accessToken' + ); + } + if (opts.webrtcOptions.signalingCredentials) { + throw new Error( + 'cannot set webrtcOptions.signalingCredentials with accessToken' + ); + } + } + } + + if ( + opts.webrtcOptions?.signalingAccessToken && + opts.webrtcOptions.signalingAccessToken.length > 0 + ) { + if (opts.webrtcOptions.signalingAuthEntity) { + throw new Error( + 'cannot set webrtcOptions.signalingAuthEntity with webrtcOptions.signalingAccessToken' + ); + } + if (opts.webrtcOptions.signalingCredentials) { + throw new Error( + 'cannot set webrtcOptions.signalingCredentials with webrtcOptions.signalingAccessToken' + ); + } + } +}; diff --git a/src/rpc/grpc-error.ts b/src/rpc/grpc-error.ts new file mode 100644 index 000000000..3c5723048 --- /dev/null +++ b/src/rpc/grpc-error.ts @@ -0,0 +1,12 @@ +export class GRPCError extends Error { + public override readonly name = 'GRPCError'; + public readonly code: number; + public readonly grpcMessage: string; + + constructor(code: number, message: string) { + super(`Code=${code} Message=${message}`); + this.code = code; + this.grpcMessage = message; + Object.setPrototypeOf(this, GRPCError.prototype); + } +} diff --git a/src/rpc/index.ts b/src/rpc/index.ts new file mode 100644 index 000000000..8097328a0 --- /dev/null +++ b/src/rpc/index.ts @@ -0,0 +1,25 @@ +import type { TransportFactory } from '@improbable-eng/grpc-web/dist/typings/transports/Transport'; +import type { CrossBrowserHttpTransportInit } from '@improbable-eng/grpc-web/dist/typings/transports/http/http'; + +declare global { + // eslint-disable-next-line vars-on-top,no-var + var VIAM: + | { + GRPC_TRANSPORT_FACTORY?: ( + opts: CrossBrowserHttpTransportInit + ) => TransportFactory; + } + | undefined; +} + +export { + dialDirect, + dialWebRTC, + type Credentials, + type DialOptions, + type DialWebRTCOptions, + type WebRTCConnection, +} from './dial'; + +export { ConnectionClosedError } from './connection-closed-error'; +export { GRPCError } from './grpc-error'; diff --git a/src/rpc/peer.ts b/src/rpc/peer.ts new file mode 100644 index 000000000..fb252d07e --- /dev/null +++ b/src/rpc/peer.ts @@ -0,0 +1,120 @@ +import { atob, btoa } from './polyfills'; + +interface ReadyPeer { + pc: RTCPeerConnection; + dc: RTCDataChannel; +} + +export const addSdpFields = ( + localDescription?: RTCSessionDescription | null, + sdpFields?: Record +) => { + const description = { + sdp: localDescription?.sdp, + type: localDescription?.type, + }; + if (sdpFields) { + for (const key of Object.keys(sdpFields)) { + description.sdp = [ + description.sdp, + `a=${key}:${sdpFields[key]}\r\n`, + ].join(''); + } + } + return description; +}; + +export const newPeerConnectionForClient = async ( + disableTrickle: boolean, + rtcConfig?: RTCConfiguration, + additionalSdpFields?: Record +): Promise => { + const usableRTCConfig = rtcConfig ?? { + iceServers: [ + { + urls: 'stun:global.stun.twilio.com:3478', + }, + ], + }; + const peerConnection = new RTCPeerConnection(usableRTCConfig); + + let pResolve: (value: ReadyPeer) => void; + const result = new Promise((resolve) => { + pResolve = resolve; + }); + const dataChannel = peerConnection.createDataChannel('data', { + id: 0, + negotiated: true, + ordered: true, + }); + dataChannel.binaryType = 'arraybuffer'; + + const negotiationChannel = peerConnection.createDataChannel('negotiation', { + id: 1, + negotiated: true, + ordered: true, + }); + negotiationChannel.binaryType = 'arraybuffer'; + + let negOpen = false; + negotiationChannel.addEventListener('open', () => { + negOpen = true; + }); + negotiationChannel.addEventListener( + 'message', + async (event: MessageEvent) => { + try { + const description = new RTCSessionDescription( + JSON.parse(atob(event.data)) as RTCSessionDescriptionInit + ); + + // we are always polite and will never ignore an offer + + await peerConnection.setRemoteDescription(description); + + if (description.type === 'offer') { + await peerConnection.setLocalDescription(); + const newDescription = addSdpFields( + peerConnection.localDescription, + additionalSdpFields + ); + negotiationChannel.send(btoa(JSON.stringify(newDescription))); + } + } catch (error) { + console.error(error); + } + } + ); + + peerConnection.addEventListener('negotiationneeded', async () => { + if (!negOpen) { + return; + } + try { + await peerConnection.setLocalDescription(); + const newDescription = addSdpFields( + peerConnection.localDescription, + additionalSdpFields + ); + negotiationChannel.send(btoa(JSON.stringify(newDescription))); + } catch (error) { + console.error(error); + } + }); + + if (!disableTrickle) { + return { pc: peerConnection, dc: dataChannel }; + } + // set up offer + const offerDesc = await peerConnection.createOffer({}); + await peerConnection.setLocalDescription(offerDesc); + + peerConnection.addEventListener('icecandidate', (event) => { + if (event.candidate !== null) { + return; + } + pResolve({ pc: peerConnection, dc: dataChannel }); + }); + + return result; +}; diff --git a/src/rpc/polyfills.ts b/src/rpc/polyfills.ts new file mode 100644 index 000000000..0fc0f3fe3 --- /dev/null +++ b/src/rpc/polyfills.ts @@ -0,0 +1,49 @@ +const chars = + 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/='; + +/* eslint-disable no-bitwise, unicorn/prefer-code-point, no-plusplus, require-unicode-regexp */ +export const btoa = (input = '') => { + const str = input; + let output = ''; + + for ( + let block = 0, charCode, i = 0, map = chars; + str.charAt(Math.trunc(i)) || ((map = '='), i % 1); + output += map.charAt(63 & (block >> (8 - (i % 1) * 8))) + ) { + charCode = str.charCodeAt((i += 3 / 4)); + + if (charCode > 0xff) { + throw new Error( + "'btoa' failed: The string to be encoded contains characters outside of the Latin1 range." + ); + } + + block = (block << 8) | charCode; + } + + return output; +}; + +export const atob = (input = '') => { + const str = input.replace(/=+$/, ''); // eslint-disable-line no-div-regex + let output = ''; + + if (str.length % 4 === 1) { + throw new Error( + "'atob' failed: The string to be decoded is not correctly encoded." + ); + } + for ( + let bc = 0, bs = 0, buffer, i = 0; + (buffer = str.charAt(i++)); // eslint-disable-line no-cond-assign + ~buffer && ((bs = bc % 4 ? bs * 64 + buffer : buffer), bc++ % 4) + ? (output += String.fromCharCode(255 & (bs >> ((-2 * bc) & 6)))) + : 0 + ) { + buffer = chars.indexOf(buffer); + } + + return output; +}; +/* eslint-enable no-bitwise, unicorn/prefer-code-point, no-plusplus, require-unicode-regexp */