From 0d860a037d0da3a6fe6c0cfe51bca8c957b14a96 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 2 Oct 2023 17:42:28 -0400 Subject: [PATCH] add socket implementation to RAI dashboard object detection and question answering metrics calls --- apps/widget/src/app/ModelAssessment.tsx | 10 ++--- apps/widget/src/app/callFlaskService.ts | 37 +++++++++++++++++++ package.json | 1 + rai_core_flask/rai_core_flask/flask_helper.py | 11 +++++- rai_core_flask/requirements.txt | 1 + .../raiwidgets/responsibleai_dashboard.py | 12 ++++++ raiwidgets/requirements-dev.txt | 2 + 7 files changed, 68 insertions(+), 6 deletions(-) diff --git a/apps/widget/src/app/ModelAssessment.tsx b/apps/widget/src/app/ModelAssessment.tsx index b44b658b83..0c4a54df4f 100644 --- a/apps/widget/src/app/ModelAssessment.tsx +++ b/apps/widget/src/app/ModelAssessment.tsx @@ -14,7 +14,7 @@ import { import { ModelAssessmentDashboard } from "@responsible-ai/model-assessment"; import React from "react"; -import { callFlaskService } from "./callFlaskService"; +import { callFlaskService, connectToFlaskService } from "./callFlaskService"; import { CallbackType, IModelAssessmentProps } from "./ModelAssessmentUtils"; export class ModelAssessment extends React.Component { @@ -33,7 +33,7 @@ export class ModelAssessment extends React.Component { objectDetectionCache: Map, abortSignal: AbortSignal ): Promise => { - return callFlaskService( + return connectToFlaskService( this.props.config, [ selectionIndexes, @@ -42,7 +42,7 @@ export class ModelAssessment extends React.Component { iouThreshold, objectDetectionCache ], - "/get_object_detection_metrics", + "handle_object_detection_json", abortSignal ); }; @@ -57,10 +57,10 @@ export class ModelAssessment extends React.Component { >, abortSignal: AbortSignal ): Promise => { - return callFlaskService( + return connectToFlaskService( this.props.config, [selectionIndexes, questionAnsweringCache], - "/get_question_answering_metrics", + "handle_question_answering_json", abortSignal ); }; diff --git a/apps/widget/src/app/callFlaskService.ts b/apps/widget/src/app/callFlaskService.ts index f3b3016aa5..34d26f22df 100644 --- a/apps/widget/src/app/callFlaskService.ts +++ b/apps/widget/src/app/callFlaskService.ts @@ -2,9 +2,14 @@ // Licensed under the MIT License. import json5 from "json5"; +import { io } from "socket.io-client"; import { IAppConfig } from "./config"; +interface IDataResponse { + data: TResponse; +} + export async function callFlaskService( config: Pick, data: TRequest, @@ -31,3 +36,35 @@ export async function callFlaskService( return Promise.reject(new Error(resp.statusText)); }); } + +export async function connectToFlaskService( + config: Pick, + data: TRequest, + urlPath: string, + abortSignal?: AbortSignal +): Promise { + return new Promise((resolve, reject) => { + if (abortSignal?.aborted) { + return reject(new Error("Aborted socket connection")); + } + const url = config.baseUrl; + const socket = io(url, { + reconnectionDelayMax: 10000 + }); + socket.on("connect", () => { + console.log(`socket connected, socket id: ${socket.id}, url: ${url}`); + }); + socket.on("disconnect", () => { + console.log("socket disconnected"); + }); + socket.emit( + urlPath, + { + data: JSON.stringify(data) + }, + (response: IDataResponse) => { + return resolve(response.data); + } + ); + }); +} diff --git a/package.json b/package.json index 5a2215d2b6..995d276b9d 100644 --- a/package.json +++ b/package.json @@ -70,6 +70,7 @@ "react-plotly.js": "^2.5.0", "react-router-dom": "^5.0.1", "regenerator-runtime": "0.13.7", + "socket.io-client": "^4.7.2", "tslib": "^2.5.0", "uuid": "^8.3.0", "vott-ct": "^2.4.2-rc.0" diff --git a/rai_core_flask/rai_core_flask/flask_helper.py b/rai_core_flask/rai_core_flask/flask_helper.py index 061075325a..389ab7a34a 100644 --- a/rai_core_flask/rai_core_flask/flask_helper.py +++ b/rai_core_flask/rai_core_flask/flask_helper.py @@ -7,8 +7,10 @@ import threading import time import uuid +import warnings from flask import Flask +from flask_socketio import SocketIO from gevent.pywsgi import WSGIServer from .environment_detector import build_environment @@ -31,6 +33,7 @@ def __init__(self, ip=None, port=None, with_credentials=False): self.with_credentials = with_credentials # dictionary to store arbitrary state for use by consuming classes self.shared_state = {} + self.socketio = SocketIO(self.app) if self.ip is None: self.ip = "localhost" if self.port is None: @@ -106,8 +109,14 @@ def run(self): logger.setLevel(logging.ERROR) self.server = WSGIServer((ip, self.port), self.app, log=logger) self.app.config["server"] = self.server + self.socketio.run(self.app, host=ip, port=self.port) self.server.serve_forever() def stop(self): if (self.server.started): - self.server.stop() + try: + self.server.stop() + except Exception as e: + warning_msg = "Caught exceptions when closing server: {}" + warning_msg = warning_msg.format(e) + warnings.warn(warning_msg, UserWarning) diff --git a/rai_core_flask/requirements.txt b/rai_core_flask/requirements.txt index d494402c4c..c564eb7edc 100644 --- a/rai_core_flask/requirements.txt +++ b/rai_core_flask/requirements.txt @@ -1,5 +1,6 @@ Flask Flask-Cors +flask-socketio ipython<=7.16.3; python_version <= '3.6' ipython>=7.31.1; python_version > '3.6' itsdangerous>=2.0.1 diff --git a/raiwidgets/raiwidgets/responsibleai_dashboard.py b/raiwidgets/raiwidgets/responsibleai_dashboard.py index df1ff30436..96676690b4 100644 --- a/raiwidgets/raiwidgets/responsibleai_dashboard.py +++ b/raiwidgets/raiwidgets/responsibleai_dashboard.py @@ -3,6 +3,8 @@ """Defines the Model Analysis Dashboard class.""" +import json + from flask import jsonify, request from raiutils.models import ModelTask @@ -114,3 +116,13 @@ def get_question_answering_metrics(): '/get_question_answering_metrics', methods=["POST"] ) + + @self._service.socketio.on('handle_object_detection_json') + def handle_object_detection_json(od_json): + od_data = json.loads(od_json['data']) + return self.input.get_object_detection_metrics(od_data) + + @self._service.socketio.on('handle_question_answering_json') + def handle_question_answering_json(qa_json): + qa_data = json.loads(qa_json['data']) + return self.input.get_question_answering_metrics(qa_data) diff --git a/raiwidgets/requirements-dev.txt b/raiwidgets/requirements-dev.txt index 908f590820..39bedbfd5b 100644 --- a/raiwidgets/requirements-dev.txt +++ b/raiwidgets/requirements-dev.txt @@ -12,6 +12,8 @@ fairlearn==0.7.0 ml-wrappers>=0.4.0 sktime pmdarima +# temporary fix for gevent error until rai-core-flask updated +shap<0.41.0 # Jupyter dependency that fails with python 3.6 pywinpty==2.0.2; python_version <= '3.6' and sys_platform == 'win32'