|
import importlib |
|
import sys |
|
from typing import Dict |
|
|
|
import ray |
|
|
|
from lagent.schema import AgentMessage |
|
from lagent.utils import load_class_from_string |
|
|
|
|
|
class AsyncAgentRayActor:
    """Wrap an agent class in a Ray actor and expose an async call interface.

    The agent class comes from ``config['type']``. It is instantiated
    remotely with the remaining config entries as keyword arguments.
    """

    def __init__(
        self,
        config: Dict,
        num_gpus: float,
    ):
        """Create the remote agent actor.

        Args:
            config: Agent configuration.
                - ``'type'`` (required) is either a class object or a string
                  that is resolved with ``load_class_from_string``.
                - ``'python_path'`` (optional) is forwarded to that loader.
                - Every other entry is passed to the agent constructor.
                The caller's dict is not modified.
            num_gpus: Number of GPUs to reserve for the actor. Ray also
                accepts fractional values here.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        # Pop the meta keys from a shallow copy. Mutating the caller's dict
        # would make a second construction from the same config fail with
        # KeyError('type').
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        AsyncAgentActor = ray.remote(num_gpus=num_gpus)(agent_cls)
        self.agent_actor = AsyncAgentActor.remote(**config)

    async def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        """Forward ``message`` to the remote agent's ``__call__`` and await it.

        Args:
            *message: Messages passed positionally to the remote agent.
            session_id: Session identifier forwarded to the remote agent.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the remote agent's ``__call__`` returns.
        """
        # ``.remote()`` returns an ObjectRef. Awaiting it yields the result
        # without blocking the event loop.
        response = await self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        return response
|
|
|
|
|
class AgentRayActor:
    """Wrap an agent class in a Ray actor and expose a blocking call interface.

    The agent class comes from ``config['type']``. It is instantiated
    remotely with the remaining config entries as keyword arguments.
    """

    def __init__(
        self,
        config: Dict,
        num_gpus: float,
    ):
        """Create the remote agent actor.

        Args:
            config: Agent configuration.
                - ``'type'`` (required) is either a class object or a string
                  that is resolved with ``load_class_from_string``.
                - ``'python_path'`` (optional) is forwarded to that loader.
                - Every other entry is passed to the agent constructor.
                The caller's dict is not modified.
            num_gpus: Number of GPUs to reserve for the actor. Ray also
                accepts fractional values here.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        # Pop the meta keys from a shallow copy. Mutating the caller's dict
        # would make a second construction from the same config fail with
        # KeyError('type').
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        AgentActor = ray.remote(num_gpus=num_gpus)(agent_cls)
        self.agent_actor = AgentActor.remote(**config)

    def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        """Forward ``message`` to the remote agent's ``__call__`` and wait.

        Args:
            *message: Messages passed positionally to the remote agent.
            session_id: Session identifier forwarded to the remote agent.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the remote agent's ``__call__`` returns.
        """
        response = self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        # ray.get blocks the calling thread until the remote result is ready.
        return ray.get(response)
|
|