Skip to content

Commit

Permalink
add socket implementation to RAI dashboard object detection and quest…
Browse files Browse the repository at this point in the history
…ion answering metrics calls
  • Loading branch information
imatiach-msft committed Oct 25, 2023
1 parent 3a5541d commit 0d860a0
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 6 deletions.
10 changes: 5 additions & 5 deletions apps/widget/src/app/ModelAssessment.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
import { ModelAssessmentDashboard } from "@responsible-ai/model-assessment";
import React from "react";

import { callFlaskService } from "./callFlaskService";
import { callFlaskService, connectToFlaskService } from "./callFlaskService";
import { CallbackType, IModelAssessmentProps } from "./ModelAssessmentUtils";

export class ModelAssessment extends React.Component<IModelAssessmentProps> {
Expand All @@ -33,7 +33,7 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
objectDetectionCache: Map<string, [number, number, number]>,
abortSignal: AbortSignal
): Promise<any[]> => {
return callFlaskService(
return connectToFlaskService(
this.props.config,
[
selectionIndexes,
Expand All @@ -42,7 +42,7 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
iouThreshold,
objectDetectionCache
],
"/get_object_detection_metrics",
"handle_object_detection_json",
abortSignal
);
};
Expand All @@ -57,10 +57,10 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
>,
abortSignal: AbortSignal
): Promise<any[]> => {
return callFlaskService(
return connectToFlaskService(
this.props.config,
[selectionIndexes, questionAnsweringCache],
"/get_question_answering_metrics",
"handle_question_answering_json",
abortSignal
);
};
Expand Down
37 changes: 37 additions & 0 deletions apps/widget/src/app/callFlaskService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
// Licensed under the MIT License.

import json5 from "json5";
import { io } from "socket.io-client";

import { IAppConfig } from "./config";

interface IDataResponse<TResponse> {
data: TResponse;
}

export async function callFlaskService<TRequest, TResponse>(
config: Pick<IAppConfig, "baseUrl" | "withCredentials">,
data: TRequest,
Expand All @@ -31,3 +36,35 @@ export async function callFlaskService<TRequest, TResponse>(
return Promise.reject(new Error(resp.statusText));
});
}

export async function connectToFlaskService<TRequest, TResponse>(
config: Pick<IAppConfig, "baseUrl" | "withCredentials">,
data: TRequest,
urlPath: string,
abortSignal?: AbortSignal
): Promise<TResponse> {
return new Promise<TResponse>((resolve, reject) => {
if (abortSignal?.aborted) {
return reject(new Error("Aborted socket connection"));
}
const url = config.baseUrl;
const socket = io(url, {
reconnectionDelayMax: 10000
});
socket.on("connect", () => {
console.log(`socket connected, socket id: ${socket.id}, url: ${url}`);
});
socket.on("disconnect", () => {
console.log("socket disconnected");
});
socket.emit(
urlPath,
{
data: JSON.stringify(data)
},
(response: IDataResponse<TResponse>) => {
return resolve(response.data);
}
);
});
}
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"react-plotly.js": "^2.5.0",
"react-router-dom": "^5.0.1",
"regenerator-runtime": "0.13.7",
"socket.io-client": "^4.7.2",
"tslib": "^2.5.0",
"uuid": "^8.3.0",
"vott-ct": "^2.4.2-rc.0"
Expand Down
11 changes: 10 additions & 1 deletion rai_core_flask/rai_core_flask/flask_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import threading
import time
import uuid
import warnings

from flask import Flask
from flask_socketio import SocketIO
from gevent.pywsgi import WSGIServer

from .environment_detector import build_environment
Expand All @@ -31,6 +33,7 @@ def __init__(self, ip=None, port=None, with_credentials=False):
self.with_credentials = with_credentials
# dictionary to store arbitrary state for use by consuming classes
self.shared_state = {}
self.socketio = SocketIO(self.app)
if self.ip is None:
self.ip = "localhost"
if self.port is None:
Expand Down Expand Up @@ -106,8 +109,14 @@ def run(self):
logger.setLevel(logging.ERROR)
self.server = WSGIServer((ip, self.port), self.app, log=logger)
self.app.config["server"] = self.server
self.socketio.run(self.app, host=ip, port=self.port)
self.server.serve_forever()

def stop(self):
if (self.server.started):
self.server.stop()
try:
self.server.stop()
except Exception as e:
warning_msg = "Caught exceptions when closing server: {}"
warning_msg = warning_msg.format(e)
warnings.warn(warning_msg, UserWarning)
1 change: 1 addition & 0 deletions rai_core_flask/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
Flask
Flask-Cors
flask-socketio
ipython<=7.16.3; python_version <= '3.6'
ipython>=7.31.1; python_version > '3.6'
itsdangerous>=2.0.1
Expand Down
12 changes: 12 additions & 0 deletions raiwidgets/raiwidgets/responsibleai_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

"""Defines the Model Analysis Dashboard class."""

import json

from flask import jsonify, request

from raiutils.models import ModelTask
Expand Down Expand Up @@ -114,3 +116,13 @@ def get_question_answering_metrics():
'/get_question_answering_metrics',
methods=["POST"]
)

@self._service.socketio.on('handle_object_detection_json')
def handle_object_detection_json(od_json):
od_data = json.loads(od_json['data'])
return self.input.get_object_detection_metrics(od_data)

@self._service.socketio.on('handle_question_answering_json')
def handle_question_answering_json(qa_json):
qa_data = json.loads(qa_json['data'])
return self.input.get_question_answering_metrics(qa_data)
2 changes: 2 additions & 0 deletions raiwidgets/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ fairlearn==0.7.0
ml-wrappers>=0.4.0
sktime
pmdarima
# temporary fix for gevent error until rai-core-flask updated
shap<0.41.0

# Jupyter dependency that fails with python 3.6
pywinpty==2.0.2; python_version <= '3.6' and sys_platform == 'win32'
Expand Down

0 comments on commit 0d860a0

Please sign in to comment.