51 changes: 51 additions & 0 deletions examples/multi_model/deploy.py
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import mii

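# GPU placement maps: host name -> list of GPU indices a deployment should occupy.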
gpu_index_map1 = {'master': [0]}
gpu_index_map2 = {'master': [1]}
gpu_index_map3 = {'master': [0, 1]}

deployments = []

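# bloom-560m: text generation split across GPUs 0 and 1 with 2-way tensor parallelism in fp16.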
name = "bigscience/bloom-560m"
deployments.append({
    'task': 'text-generation',
    'model': name,
    'deployment_name': name + "_deployment",
    'GPU_index_map': gpu_index_map3,
    'tensor_parallel': 2,
    'dtype': "fp16"
})

# DialogRPT-human-vs-rand (GPT-2 based): response-ranking classifier on GPU 1.
name = "microsoft/DialogRPT-human-vs-rand"
deployments.append({
    'task': 'text-classification',
    'model': name,
    'deployment_name': name + "_deployment",
    'GPU_index_map': gpu_index_map2
})

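# DialoGPT-large: conversational model pinned to GPU 0.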
name = "microsoft/DialoGPT-large"
deployments.append({
    'task': 'conversational',
    'model': name,
    'deployment_name': name + "_deployment",
    'GPU_index_map': gpu_index_map1
})

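# RoBERTa-large fine-tuned on SQuAD 2.0: question answering on GPU 1, with a custom deployment name.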
name = "deepset/roberta-large-squad2"
deployments.append({
    'task': 'question-answering',
    'model': name,
    'deployment_name': name + "-qa-deployment",
    'GPU_index_map': gpu_index_map2
})

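# Deploy all four models under one deployment tag; queries are later routed by deployment_name.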
mii.deploy(deployment_tag="multi_models", deployment_configs=deployments)
50 changes: 50 additions & 0 deletions examples/multi_model/query.py
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import mii

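# Attach to the "multi_models" deployment created by deploy.py.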
results = []
generator = mii.mii_query_handle("multi_models")
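# Text generation: a batch of two prompts sent to the bloom deployment.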
result = generator.query(
    {
        "query": ["DeepSpeed is", "Seattle is"],
        "deployment_name": "bigscience/bloom-560m_deployment"
    },
    do_sample=True,
    max_new_tokens=30,
)
results.append(result)
print(result)

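# Text classification against the DialogRPT deployment.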
result = generator.query({
    "query": "DeepSpeed is the greatest",
    "deployment_name": "microsoft/DialogRPT-human-vs-rand_deployment"
})
results.append(result)
print(result)

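# Conversational query against DialoGPT; history travels in past_user_inputs / generated_responses.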
result = generator.query({
    "text": "DeepSpeed is the greatest",
    "conversation_id": 3,
    "past_user_inputs": [],
    "generated_responses": [],
    "deployment_name": "microsoft/DialoGPT-large_deployment"
})
results.append(result)
print(result)

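# Extractive question answering against the RoBERTa SQuAD 2.0 deployment.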
result = generator.query({
    "question": "What is the greatest?",
    "context": "DeepSpeed is the greatest",
    "deployment_name": "deepset/roberta-large-squad2-qa-deployment"
})
results.append(result)
print(result)
7 changes: 7 additions & 0 deletions examples/multi_model/shutdown.py
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import mii

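# Shut down every model server registered under the "multi_models" tag.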
mii.terminate("multi_models")
4 changes: 2 additions & 2 deletions mii/__init__.py
@@ -7,10 +7,10 @@
 from .client import MIIClient, mii_query_handle
 from .deployment import deploy
 from .terminate import terminate
-from .constants import DeploymentType, Tasks
+from .constants import DeploymentType, TaskType
 from .aml_related.utils import aml_output_path

-from .config import MIIConfig, LoadBalancerConfig
+from .config import MIIConfig, DeploymentConfig
 from .grpc_related.proto import modelresponse_pb2_grpc

 __version__ = "0.0.0"
113 changes: 69 additions & 44 deletions mii/client.py
@@ -6,20 +6,15 @@
 import grpc
 import requests
 import mii
-from mii.utils import get_task
 from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
-from mii.constants import GRPC_MAX_MSG_SIZE, Tasks
+from mii.constants import GRPC_MAX_MSG_SIZE, TaskType
 from mii.method_table import GRPC_METHOD_TABLE
+from mii.config import MIIConfig


-def _get_deployment_info(deployment_name):
-    configs = mii.utils.import_score_file(deployment_name).configs
-    task = configs[mii.constants.TASK_NAME_KEY]
-    mii_configs_dict = configs[mii.constants.MII_CONFIGS_KEY]
-    mii_configs = mii.config.MIIConfig(**mii_configs_dict)
-
-    assert task is not None, "The task name should be set before calling init"
-    return task, mii_configs
+def _get_mii_config(deployment_name):
+    mii_config = mii.utils.import_score_file(deployment_name).mii_config
+    return MIIConfig(**mii_config)


@@ -39,40 +34,64 @@ def mii_query_handle(deployment_name):
         inference_pipeline, task = mii.non_persistent_models[deployment_name]
         return MIINonPersistentClient(task, deployment_name)

-    task_name, mii_configs = _get_deployment_info(deployment_name)
-    return MIIClient(task_name, "localhost", mii_configs.port_number)
+    mii_config = _get_mii_config(deployment_name)
+    return MIIClient(mii_config, "localhost", mii_config.port_number)


 def create_channel(host, port):
-    return grpc.aio.insecure_channel(f'{host}:{port}',
-                                     options=[('grpc.max_send_message_length',
-                                               GRPC_MAX_MSG_SIZE),
-                                              ('grpc.max_receive_message_length',
-                                               GRPC_MAX_MSG_SIZE)])
+    return grpc.aio.insecure_channel(
+        f"{host}:{port}",
+        options=[
+            ("grpc.max_send_message_length",
+             GRPC_MAX_MSG_SIZE),
+            ("grpc.max_receive_message_length",
+             GRPC_MAX_MSG_SIZE),
+        ],
+    )


-class MIIClient():
+class MIIClient:
     """
     Client to send queries to a single endpoint.
     """
-    def __init__(self, task_name, host, port):
+    def __init__(self, mii_config, host, port):
         self.asyncio_loop = asyncio.get_event_loop()
         channel = create_channel(host, port)
         self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
-        self.task = get_task(task_name)
-
-    async def _request_async_response(self, request_dict, **query_kwargs):
-        if self.task not in GRPC_METHOD_TABLE:
-            raise ValueError(f"unknown task: {self.task}")
-
-        task_methods = GRPC_METHOD_TABLE[self.task]
+        self.mii_config = mii_config
+
+    def _get_deployment_task(self, deployment_name=None):
+        # Assumes deployment_configs is a list of DeploymentConfig objects,
+        # each carrying deployment_name and task attributes.
+        deployments = {
+            d.deployment_name: d for d in self.mii_config.deployment_configs
+        }
+        if deployment_name is None:  # mii.terminate() or single-model deployment
+            assert len(deployments) == 1, "Must pass deployment_name to query when using multiple deployments"
+            deployment = self.mii_config.deployment_configs[0]
+            deployment_name = deployment.deployment_name
+        elif deployment_name in deployments:
+            deployment = deployments[deployment_name]
+        else:
+            assert False, f"{deployment_name} not found in list of deployments"
+        return deployment_name, deployment.task

+    async def _request_async_response(self, request_dict, task, **query_kwargs):
+        if task not in GRPC_METHOD_TABLE:
+            raise ValueError(f"unknown task: {task}")
+
+        task_methods = GRPC_METHOD_TABLE[task]
         proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
         proto_response = await getattr(self.stub, task_methods.method)(proto_request)
         return task_methods.unpack_response_from_proto(proto_response)

     def query(self, request_dict, **query_kwargs):
+        deployment_name = request_dict.get(mii.constants.DEPLOYMENT_NAME_KEY)
+        deployment_name, task = self._get_deployment_task(deployment_name)
+        request_dict["deployment_name"] = deployment_name
         return self.asyncio_loop.run_until_complete(
-            self._request_async_response(request_dict, **query_kwargs))
+            self._request_async_response(request_dict, task, **query_kwargs))

     async def terminate_async(self):
@@ -87,7 +106,9 @@ async def create_session_async(self, session_id):
             modelresponse_pb2.SessionID(session_id=session_id))

     def create_session(self, session_id):
-        assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'."
+        assert (
+            self.task == TaskType.TEXT_GENERATION
+        ), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'."
         return self.asyncio_loop.run_until_complete(
             self.create_session_async(session_id))

@@ -96,18 +117,20 @@ async def destroy_session_async(self, session_id):
        )

     def destroy_session(self, session_id):
-        assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'."
+        assert (
+            self.task == TaskType.TEXT_GENERATION
+        ), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'."
         self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))


-class MIITensorParallelClient():
+class MIITensorParallelClient:
     """
     Client to send queries to multiple endpoints in parallel.
     This is used to call multiple servers deployed for tensor parallelism.
     """
-    def __init__(self, task_name, host, ports):
-        self.task = get_task(task_name)
-        self.clients = [MIIClient(task_name, host, port) for port in ports]
+    def __init__(self, task, host, ports):
+        self.task = task
+        self.clients = [MIIClient(task, host, port) for port in ports]
         self.asyncio_loop = asyncio.get_event_loop()

     # runs task in parallel and return the result from the first task
@@ -155,30 +178,32 @@ def destroy_session(self, session_id):
             client.destroy_session(session_id)


-class MIINonPersistentClient():
+class MIINonPersistentClient:
     def __init__(self, task, deployment_name):
         self.task = task
         self.deployment_name = deployment_name

     def query(self, request_dict, **query_kwargs):
-        assert self.deployment_name in mii.non_persistent_models, f"deployment: {self.deployment_name} not found"
+        assert (
+            self.deployment_name in mii.non_persistent_models
+        ), f"deployment: {self.deployment_name} not found"
         task_methods = GRPC_METHOD_TABLE[self.task]
         inference_pipeline = mii.non_persistent_models[self.deployment_name][0]

-        if self.task == Tasks.QUESTION_ANSWERING:
-            if 'question' not in request_dict or 'context' not in request_dict:
+        if self.task == TaskType.QUESTION_ANSWERING:
+            if "question" not in request_dict or "context" not in request_dict:
                 raise Exception(
                     "Question Answering Task requires 'question' and 'context' keys")
             args = (request_dict["question"], request_dict["context"])
             kwargs = query_kwargs

-        elif self.task == Tasks.CONVERSATIONAL:
+        elif self.task == TaskType.CONVERSATIONAL:
             conv = task_methods.create_conversation(request_dict, **query_kwargs)
             args = (conv, )
             kwargs = {}

         else:
-            args = (request_dict['query'], )
+            args = (request_dict["query"], )
             kwargs = query_kwargs

         # Run the pipeline with the per-task args and kwargs assembled above.
         return task_methods.run_inference(inference_pipeline, args, kwargs)
@@ -189,6 +214,6 @@ def terminate(self):


 def terminate_restful_gateway(deployment_name):
-    _, mii_configs = _get_deployment_info(deployment_name)
-    if mii_configs.enable_restful_api:
-        requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
+    mii_config = _get_mii_config(deployment_name)
+    if mii_config.enable_restful_api:
+        requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate")