From 47a3af1b2310cf509d0db69b1a25b25f135e8dc8 Mon Sep 17 00:00:00 2001
From: Puja Jagani <puja.jagani93@gmail.com>
Date: Fri, 4 Apr 2025 18:00:53 +0530
Subject: [PATCH 1/3] [js][bidi] Add request and response handler

---
 javascript/selenium-webdriver/bidi/network.js |  16 +-
 .../selenium-webdriver/bidi/networkTypes.js   |  11 +
 javascript/selenium-webdriver/lib/http.js     |  10 +-
 javascript/selenium-webdriver/lib/network.js  | 175 +++++++++++++++
 .../selenium-webdriver/lib/test/fileserver.js |  29 +++
 .../test/lib/webdriver_network_test.js        | 202 ++++++++++++++++++
 6 files changed, 433 insertions(+), 10 deletions(-)

diff --git a/javascript/selenium-webdriver/bidi/network.js b/javascript/selenium-webdriver/bidi/network.js
index 8e860b33f2bdd..4f40eed6604c6 100644
--- a/javascript/selenium-webdriver/bidi/network.js
+++ b/javascript/selenium-webdriver/bidi/network.js
@@ -154,26 +154,26 @@ class Network {
 
     this.ws = await this.bidi.socket
     this.ws.on('message', (event) => {
-      const { params } = JSON.parse(Buffer.from(event.toString()))
+      const { method, params } = JSON.parse(Buffer.from(event.toString()))
       if (params) {
         let response = null
-        if ('initiator' in params) {
-          response = new BeforeRequestSent(
+        if ('request' in params && 'response' in params) {
+          response = new ResponseStarted(
             params.context,
             params.navigation,
             params.redirectCount,
             params.request,
             params.timestamp,
-            params.initiator,
+            params.response,
           )
-        } else if ('response' in params) {
-          response = new ResponseStarted(
+        } else if ('initiator' in params && !('response' in params)) {
+          response = new BeforeRequestSent(
             params.context,
             params.navigation,
             params.redirectCount,
             params.request,
             params.timestamp,
-            params.response,
+            params.initiator,
           )
         } else if ('errorText' in params) {
           response = new FetchError(
@@ -185,7 +185,7 @@ class Network {
             params.errorText,
           )
         }
-        this.invokeCallbacks(eventType, response)
+        this.invokeCallbacks(method, response)
       }
     })
     return id
diff --git a/javascript/selenium-webdriver/bidi/networkTypes.js b/javascript/selenium-webdriver/bidi/networkTypes.js
index 8a62aaa5d3964..5a2d4ea35a736 100644
--- a/javascript/selenium-webdriver/bidi/networkTypes.js
+++ b/javascript/selenium-webdriver/bidi/networkTypes.js
@@ -118,6 +118,17 @@ class Header {
   get value() {
     return this._value
   }
+
+  /**
+   * Converts the Header to a map.
+   * @returns {Map<string, string>} A map representation of the Header.
+   */
+  asMap() {
+    const map = new Map()
+    map.set('name', this._name)
+    map.set('value', Object.fromEntries(this._value.asMap()))
+    return map
+  }
 }
 
 /**
diff --git a/javascript/selenium-webdriver/lib/http.js b/javascript/selenium-webdriver/lib/http.js
index d09e4fef8bbf9..18d16f7caed19 100644
--- a/javascript/selenium-webdriver/lib/http.js
+++ b/javascript/selenium-webdriver/lib/http.js
@@ -88,12 +88,18 @@ class Request {
    * @param {string} method The HTTP method to use for the request.
    * @param {string} path The path on the server to send the request to.
    * @param {Object=} opt_data This request's non-serialized JSON payload data.
+   * @param {Map<string, string>} [headers=new Map()] - The optional headers as a Map.
    */
-  constructor(method, path, opt_data) {
+  constructor(method, path, opt_data, headers = new Map()) {
     this.method = /** string */ method
     this.path = /** string */ path
     this.data = /** Object */ opt_data
-    this.headers = /** !Map<string, string> */ new Map([['Accept', 'application/json; charset=utf-8']])
+
+    if (headers.size > 0) {
+      this.headers = headers
+    } else {
+      this.headers = /** !Map<string, string> */ new Map([['Accept', 'application/json; charset=utf-8']])
+    }
   }
 
   /** @override */
diff --git a/javascript/selenium-webdriver/lib/network.js b/javascript/selenium-webdriver/lib/network.js
index cfc5873804d53..bdca20c998cdc 100644
--- a/javascript/selenium-webdriver/lib/network.js
+++ b/javascript/selenium-webdriver/lib/network.js
@@ -18,12 +18,18 @@
 const { Network: getNetwork } = require('../bidi/network')
 const { InterceptPhase } = require('../bidi/interceptPhase')
 const { AddInterceptParameters } = require('../bidi/addInterceptParameters')
+const { ContinueRequestParameters } = require('../bidi/continueRequestParameters')
+const { ProvideResponseParameters } = require('../bidi/provideResponseParameters')
+const { Request } = require('./http')
+const { BytesValue, Header } = require('../bidi/networkTypes')
 
 class Network {
   #callbackId = 0
   #driver
   #network
   #authHandlers = new Map()
+  #requestHandlers = new Map()
+  #responseHandlers = new Map()
 
   constructor(driver) {
     this.#driver = driver
@@ -43,6 +49,8 @@ class Network {
 
     await this.#network.addIntercept(new AddInterceptParameters(InterceptPhase.AUTH_REQUIRED))
 
+    await this.#network.addIntercept(new AddInterceptParameters(InterceptPhase.BEFORE_REQUEST_SENT))
+
     await this.#network.authRequired(async (event) => {
       const requestId = event.request.request
       const uri = event.request.url
@@ -54,6 +62,76 @@ class Network {
 
       await this.#network.continueWithAuthNoCredentials(requestId)
     })
+
+    await this.#network.beforeRequestSent(async (event) => {
+      const requestId = event.request.request
+      const requestData = event.request
+
+      // Build the original request from the intercepted request details.
+      const originalRequest = new Request(requestData.method, requestData.url, null, new Map(requestData.headers))
+
+      let requestHandler = this.getRequestHandler(originalRequest)
+      let responseHandler = this.getResponseHandler(originalRequest)
+
+      // Populate the headers of the original request.
+      // Body is not available as part of WebDriver Spec, hence we cannot add that or use that.
+
+      const continueRequestParams = new ContinueRequestParameters(requestId)
+
+      // If a response handler exists, we mock the response instead of modifying the outgoing request
+      if (responseHandler !== null) {
+        const modifiedResponse = await responseHandler()
+
+        const provideResponseParams = new ProvideResponseParameters(requestId)
+        provideResponseParams.statusCode(modifiedResponse.status)
+
+        // Convert headers
+        if (modifiedResponse.headers.size > 0) {
+          const headers = []
+
+          modifiedResponse.headers.forEach((value, key) => {
+            headers.push(new Header(key, new BytesValue('string', value)))
+          })
+          provideResponseParams.headers(headers)
+        }
+
+        // Convert body if available
+        if (modifiedResponse.body && modifiedResponse.body.length > 0) {
+          provideResponseParams.body(new BytesValue('string', modifiedResponse.body))
+        }
+
+        await this.#network.provideResponse(provideResponseParams)
+        return
+      }
+
+      // If request handler exists, modify the request
+      if (requestHandler !== null) {
+        const modifiedRequest = requestHandler(originalRequest)
+
+        continueRequestParams.method(modifiedRequest.method)
+
+        if (originalRequest.path !== modifiedRequest.path) {
+          continueRequestParams.url(modifiedRequest.path)
+        }
+
+        // Convert headers
+        if (modifiedRequest.headers.size > 0) {
+          const headers = []
+
+          modifiedRequest.headers.forEach((value, key) => {
+            headers.push(new Header(key, new BytesValue('string', value)))
+          })
+          continueRequestParams.headers(headers)
+        }
+
+        if (modifiedRequest.data && modifiedRequest.data.length > 0) {
+          continueRequestParams.body(new BytesValue('string', modifiedRequest.data))
+        }
+      }
+
+      // Continue with the modified or original request
+      await this.#network.continueRequest(continueRequestParams)
+    })
   }
 
   getAuthCredentials(uri) {
@@ -64,6 +142,27 @@ class Network {
     }
     return null
   }
+
+  getRequestHandler(req) {
+    for (let [, value] of this.#requestHandlers) {
+      const filter = value.filter
+      if (filter(req)) {
+        return value.handler
+      }
+    }
+    return null
+  }
+
+  getResponseHandler(req) {
+    for (let [, value] of this.#responseHandlers) {
+      const filter = value.filter
+      if (filter(req)) {
+        return value.handler
+      }
+    }
+    return null
+  }
+
   async addAuthenticationHandler(username, password, uri = '//') {
     await this.#init()
 
@@ -86,6 +185,82 @@ class Network {
   async clearAuthenticationHandlers() {
     this.#authHandlers.clear()
   }
+
+  /**
+   * Adds a request handler that filters requests based on a predicate function.
+   * @param {Function} filter - A function that takes an HTTP request and returns true or false.
+   * @param {Function} handler - A function that takes an HTTP request and returns a modified request.
+   * @returns {number} - A unique handler ID.
+   * @throws {Error} - If filter is not a function or handler does not return a request.
+   */
+  async addRequestHandler(filter, handler) {
+    if (typeof filter !== 'function') {
+      throw new Error('Filter must be a function.')
+    }
+
+    if (typeof handler !== 'function') {
+      throw new Error('Handler must be a function.')
+    }
+
+    await this.#init()
+
+    const id = this.#callbackId++
+
+    this.#requestHandlers.set(id, { filter, handler })
+    return id
+  }
+
+  async removeRequestHandler(id) {
+    await this.#init()
+
+    if (this.#requestHandlers.has(id)) {
+      this.#requestHandlers.delete(id)
+    } else {
+      throw Error(`Callback with id ${id} not found`)
+    }
+  }
+
+  async clearRequestHandlers() {
+    this.#requestHandlers.clear()
+  }
+
+  /**
+   * Adds a response handler that mocks responses.
+   * @param {Function} filter - A function that takes an HTTP request, returning a boolean.
+   * @param {Function} handler - A function that returns a mocked HTTP response.
+   * @returns {number} - A unique handler ID.
+   * @throws {Error} - If filter is not a function or handler is not an async function.
+   */
+  async addResponseHandler(filter, handler) {
+    if (typeof filter !== 'function') {
+      throw new Error('Filter must be a function.')
+    }
+
+    if (typeof handler !== 'function') {
+      throw new Error('Handler must be a function.')
+    }
+
+    await this.#init()
+
+    const id = this.#callbackId++
+
+    this.#responseHandlers.set(id, { filter, handler })
+    return id
+  }
+
+  async removeResponseHandler(id) {
+    await this.#init()
+
+    if (this.#responseHandlers.has(id)) {
+      this.#responseHandlers.delete(id)
+    } else {
+      throw Error(`Callback with id ${id} not found`)
+    }
+  }
+
+  async clearResponseHandlers() {
+    this.#responseHandlers.clear()
+  }
 }
 
 module.exports = Network
diff --git a/javascript/selenium-webdriver/lib/test/fileserver.js b/javascript/selenium-webdriver/lib/test/fileserver.js
index 7023cd8d6fe3b..ce280043d0a5e 100644
--- a/javascript/selenium-webdriver/lib/test/fileserver.js
+++ b/javascript/selenium-webdriver/lib/test/fileserver.js
@@ -45,6 +45,7 @@ const Pages = (function () {
     })
   }
 
+  addPage('addRequestBody', 'addRequestBody')
   addPage('ajaxyPage', 'ajaxy_page.html')
   addPage('alertsPage', 'alerts.html')
   addPage('basicAuth', 'basicAuth')
@@ -131,6 +132,7 @@ const Path = {
   PAGE: WEB_ROOT + '/page',
   SLEEP: WEB_ROOT + '/sleep',
   UPLOAD: WEB_ROOT + '/upload',
+  ADD_REQUEST_BODY: WEB_ROOT + '/addRequestBody',
 }
 
 var app = express()
@@ -143,6 +145,7 @@ app
   })
   .use(JS_ROOT, serveIndex(jsDirectory), express.static(jsDirectory))
   .post(Path.UPLOAD, handleUpload)
+  .post(Path.ADD_REQUEST_BODY, addRequestBody)
   .use(WEB_ROOT, serveIndex(baseDirectory), express.static(baseDirectory))
   .use(DATA_ROOT, serveIndex(dataDirectory), express.static(dataDirectory))
   .get(Path.ECHO, sendEcho)
@@ -187,6 +190,32 @@ function sendInifinitePage(request, response) {
   response.end(body)
 }
 
+function addRequestBody(request, response) {
+  let requestBody = ''
+
+  request.on('data', (chunk) => {
+    requestBody += chunk
+  })
+
+  request.on('end', () => {
+    let body = [
+      '<!DOCTYPE html>',
+      '<html>',
+      '<head><title>Page</title></head>',
+      '<body>',
+      `<p>Request Body:</p><pre>${requestBody}</pre>`,
+      '</body>',
+      '</html>',
+    ].join('')
+
+    response.writeHead(200, {
+      'Content-Length': Buffer.byteLength(body, 'utf8'),
+      'Content-Type': 'text/html; charset=utf-8',
+    })
+    response.end(body)
+  })
+}
+
 function sendBasicAuth(request, response) {
   const denyAccess = function () {
     response.writeHead(401, { 'WWW-Authenticate': 'Basic realm="test"' })
diff --git a/javascript/selenium-webdriver/test/lib/webdriver_network_test.js b/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
index ff1ce496bf038..9c5a4c4669758 100644
--- a/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
+++ b/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
@@ -22,6 +22,8 @@ const { Browser } = require('selenium-webdriver')
 const { Pages, suite } = require('../../lib/test')
 const until = require('selenium-webdriver/lib/until')
 const { By } = require('selenium-webdriver')
+const { Request, Response } = require('../../http')
+const { Network } = require('../../bidi/network')
 
 suite(
   function (env) {
@@ -112,6 +114,206 @@ suite(
           assert.strictEqual(e.name, 'UnexpectedAlertOpenError')
         }
       })
+
+      it('can add request handler to modify method', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Request('HEAD', Pages.logEntryAdded, null)
+
+        await driver.network().addRequestHandler(filter, handler)
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('log entry added events'), false)
+      })
+
+      it('can add request handler to modify uri', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Request('GET', Pages.blankPage, null)
+
+        await driver.network().addRequestHandler(filter, handler)
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('blank'), true)
+      })
+
+      it('can add request handler to modify body', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Request('POST', Pages.addRequestBody, 'hello world!')
+
+        await driver.network().addRequestHandler(filter, handler)
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('hello world'), true)
+      })
+
+      it('can add multiple request handlers', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Request('GET', Pages.blankPage, null)
+
+        await driver.network().addRequestHandler(filter, handler)
+
+        await driver.network().addRequestHandler(
+          (req) => req.path.includes('hello.html'),
+          () => new Request('GET', Pages.logEntryAdded, null),
+        )
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('blank'), true)
+      })
+
+      it('can add multiple request handlers with same filter', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Request('GET', Pages.blankPage, null)
+
+        await driver.network().addRequestHandler(filter, handler)
+
+        await driver.network().addRequestHandler(filter, handler)
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('blank'), true)
+      })
+
+      it('can remove request handler', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Request('GET', Pages.blankPage, null)
+
+        const id = await driver.network().addRequestHandler(filter, handler)
+
+        await driver.network().removeRequestHandler(id)
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('entry added'), true)
+      })
+
+      it('throws an error when removing request handler that does not exist', async function () {
+        try {
+          await driver.network().removeRequestHandler(10)
+          assert.fail('Expected error not thrown. Non-existent handler cannot be removed')
+        } catch (e) {
+          assert.strictEqual(e.message, 'Callback with id 10 not found')
+        }
+      })
+
+      it('can clear request handlers', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Request('GET', Pages.blankPage, null)
+
+        await driver.network().addRequestHandler(filter, handler)
+
+        await driver.network().addRequestHandler(
+          (req) => req.path.includes('hello.html'),
+          () => new Request('GET', Pages.logEntryAdded, null),
+        )
+
+        await driver.network().clearRequestHandlers()
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('entry added'), true)
+      })
+
+      it('can add response handler to mock complete response', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Response(500, { test: 'header-value' }, 'Internal server error')
+
+        const network = await Network(driver)
+
+        let onResponseCompleted = null
+
+        await network.responseStarted(function (event) {
+          if (event.response.url.includes('logEntryAdded')) {
+            onResponseCompleted = event.response
+          }
+        })
+
+        await driver.network().addResponseHandler(filter, handler)
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('Internal server error'), true)
+
+        assert.strictEqual(onResponseCompleted.status, 500)
+        assert.strictEqual(onResponseCompleted.headers[0].name, 'test')
+        assert.strictEqual(onResponseCompleted.headers[0].value.value, 'header-value')
+      })
+
+      it('can add multiple response handler with same filter', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Response(500, { test: 'header-value' }, 'Internal server error')
+
+        const network = await Network(driver)
+
+        let onResponseCompleted = null
+
+        await network.responseStarted(function (event) {
+          if (event.response.url.includes('logEntryAdded')) {
+            onResponseCompleted = event.response
+          }
+        })
+
+        await driver.network().addResponseHandler(filter, handler)
+        await driver.network().addResponseHandler(filter, handler)
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('Internal server error'), true)
+
+        assert.strictEqual(onResponseCompleted.status, 500)
+        assert.strictEqual(onResponseCompleted.headers[0].name, 'test')
+        assert.strictEqual(onResponseCompleted.headers[0].value.value, 'header-value')
+      })
+
+      it('throws an error when removing response handler that does not exist', async function () {
+        try {
+          await driver.network().removeResponseHandler(10)
+          assert.fail('Expected error not thrown. Non-existent handler cannot be removed')
+        } catch (e) {
+          assert.strictEqual(e.message, 'Callback with id 10 not found')
+        }
+      })
+
+      it('can clear response handlers', async function () {
+        const filter = (req) => req.path.includes('bidi/logEntryAdded.html')
+        const handler = () => new Response(200, { test: 'header' }, 'Hello!')
+
+        await driver.network().addResponseHandler(filter, handler)
+
+        await driver.network().addResponseHandler(
+          (req) => req.path.includes('hello.html'),
+          () => new Response(401, { test: 'header' }, 'Not found!'),
+        )
+
+        await driver.network().clearResponseHandlers()
+
+        await driver.get(Pages.logEntryAdded)
+
+        const pageSource = await driver.getPageSource()
+
+        assert.strictEqual(pageSource.includes('entry added'), true)
+      })
     })
   },
   { browsers: [Browser.FIREFOX] },

From be0e52f65a1cb13f53f5f430b88178167ba51994 Mon Sep 17 00:00:00 2001
From: Puja Jagani <puja.jagani93@gmail.com>
Date: Mon, 7 Apr 2025 12:45:34 +0530
Subject: [PATCH 2/3] Fixed import

---
 .../selenium-webdriver/test/lib/webdriver_network_test.js       | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/javascript/selenium-webdriver/test/lib/webdriver_network_test.js b/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
index 9c5a4c4669758..5b8c869f41e62 100644
--- a/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
+++ b/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
@@ -22,7 +22,7 @@ const { Browser } = require('selenium-webdriver')
 const { Pages, suite } = require('../../lib/test')
 const until = require('selenium-webdriver/lib/until')
 const { By } = require('selenium-webdriver')
-const { Request, Response } = require('../../http')
+const { Request, Response } = require('selenium-webdriver/http')
 const { Network } = require('../../bidi/network')
 
 suite(

From 515072c026353af82ed78c4e7ce8f12f9c900353 Mon Sep 17 00:00:00 2001
From: Puja Jagani <puja.jagani93@gmail.com>
Date: Mon, 7 Apr 2025 13:07:08 +0530
Subject: [PATCH 3/3] Fix import

---
 .../selenium-webdriver/test/lib/webdriver_network_test.js       | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/javascript/selenium-webdriver/test/lib/webdriver_network_test.js b/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
index 5b8c869f41e62..2d9a9ad682bf3 100644
--- a/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
+++ b/javascript/selenium-webdriver/test/lib/webdriver_network_test.js
@@ -23,7 +23,7 @@ const { Pages, suite } = require('../../lib/test')
 const until = require('selenium-webdriver/lib/until')
 const { By } = require('selenium-webdriver')
 const { Request, Response } = require('selenium-webdriver/http')
-const { Network } = require('../../bidi/network')
+const { Network } = require('selenium-webdriver/bidi/network')
 
 suite(
   function (env) {