Spaces:
Runtime error
Runtime error
File size: 5,648 Bytes
3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 76166e3 3e274d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from typing import Dict, List, Optional, Any
from crewai import Agent, Task
import logging
from utils.log_manager import LogManager
from pydantic import Field, BaseModel, ConfigDict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class BaseWellnessAgent(Agent):
    """Base agent class with Mistral LLM support.

    Wraps a CrewAI ``Agent`` and adds a locally loaded Hugging Face causal
    LM (Mistral) for response generation. Per-agent settings — model id,
    instruction template, sampling parameters — are read from
    ``config[agent_type]``.
    """

    # Allow non-pydantic types (HF model/tokenizer, LogManager) as fields.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Fields populated during __init__ / _initialize_model.
    log_manager: LogManager = Field(default_factory=LogManager)
    logger: Optional[logging.Logger] = Field(default=None)
    config: Dict = Field(default_factory=dict)
    model: Any = Field(default=None)
    tokenizer: Any = Field(default=None)
    agent_type: str = Field(default="base")

    def __init__(self, model_config: Dict, agent_type: str, **kwargs):
        """Initialize the CrewAI agent, then load the Mistral model.

        Args:
            model_config: Mapping of agent_type -> settings dict
                (``model_id``, ``instruction_template``, sampling params).
            agent_type: Key into ``model_config`` selecting this agent's
                settings; also used for the per-agent logger.
            **kwargs: Forwarded to ``crewai.Agent`` (role, goal, backstory,
                verbose, allow_delegation, tools, ...).

        Raises:
            Exception: propagated from model loading (see _initialize_model).
        """
        # BUG FIX: the original passed role=kwargs.get("role", ...) AND
        # **kwargs to super().__init__, which raises
        # "TypeError: got multiple values for keyword argument 'role'"
        # whenever the caller actually supplies role/goal/etc.
        # setdefault applies the same defaults without duplicating keys.
        kwargs.setdefault("role", "Wellness Support Agent")
        kwargs.setdefault("goal", "Support mental wellness")
        kwargs.setdefault("backstory", "I am an AI agent specialized in mental health support.")
        kwargs.setdefault("verbose", True)
        kwargs.setdefault("allow_delegation", False)
        kwargs.setdefault("tools", [])
        super().__init__(**kwargs)

        # Initialize logging and configuration.
        self.config = model_config
        self.agent_type = agent_type
        self.logger = self.log_manager.get_agent_logger(agent_type)

        # Load model/tokenizer eagerly so the agent is ready before any task.
        self._initialize_model()
        self.logger.info(f"{agent_type.capitalize()} Agent initialized")

    def _initialize_model(self):
        """Load the tokenizer and quantized causal LM for this agent type.

        Reads ``self.config[self.agent_type]["model_id"]``. Logs and
        re-raises any failure so construction aborts loudly.
        """
        try:
            model_config = self.config[self.agent_type]
            self.logger.info(f"Initializing Mistral model: {model_config['model_id']}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_config["model_id"])
            # NOTE(review): load_in_4bit requires bitsandbytes and is
            # deprecated in recent transformers in favor of
            # quantization_config=BitsAndBytesConfig(load_in_4bit=True);
            # torch.float32 here acts as the de-quantized compute dtype,
            # which is unusually heavy for 4-bit — confirm intent.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_config["model_id"],
                torch_dtype=torch.float32,
                device_map="auto",
                load_in_4bit=True,
            )
            self.logger.info("Mistral model initialized successfully")
        except Exception as e:
            self.logger.error(f"Error initializing Mistral model: {str(e)}")
            raise

    def _generate_response(self, input_text: str) -> str:
        """Generate a model response for *input_text*.

        Formats the agent's instruction template, samples a continuation,
        and returns only the newly generated text (prompt stripped). On any
        failure, logs and returns a fixed apology string instead of raising.
        """
        try:
            agent_cfg = self.config[self.agent_type]
            prompt = agent_cfg["instruction_template"].format(input=input_text)
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(
                **inputs,
                # BUG FIX: max_new_tokens bounds only the *generated* text;
                # the original max_length counted prompt tokens too, so a
                # long prompt could leave no room for a reply at all.
                max_new_tokens=agent_cfg.get("max_length", 4096),
                temperature=agent_cfg.get("temperature", 0.7),
                top_p=agent_cfg.get("top_p", 0.95),
                repetition_penalty=agent_cfg.get("repetition_penalty", 1.1),
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # BUG FIX: slice off the prompt *tokens* before decoding. The
            # original did response.replace(prompt, ""), but round-tripped
            # text rarely matches the raw prompt byte-for-byte (tokenizer
            # normalization), so the prompt leaked into the output.
            prompt_len = inputs["input_ids"].shape[-1]
            response = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}")
            return "I apologize, but I encountered an error generating a response."

    def execute_task(self, task: Task) -> str:
        """Execute a CrewAI task by treating its description as a message.

        Returns the generated reply text, or an apology string on failure.
        """
        self.logger.info(f"Executing task: {task.description}")
        try:
            result = self.process_message(task.description)
            return result["message"]
        except Exception as e:
            self.logger.error(f"Error executing task: {str(e)}")
            return "I apologize, but I encountered an error processing your request."

    def process_message(self, message: str, context: Dict = None) -> Dict:
        """Process *message* and return a response envelope.

        Args:
            message: User/task text to respond to.
            context: Optional extra context; currently accepted but unused
                beyond normalization to an empty dict.

        Returns:
            Dict with ``message``, ``agent_type`` and ``task_type``
            ("dialogue" on success, "error_recovery" on failure).
        """
        self.logger.info("Processing message")
        context = context or {}
        try:
            response = self._generate_response(message)
            return {
                "message": response,
                "agent_type": self.agent_type,
                "task_type": "dialogue"
            }
        except Exception as e:
            self.logger.error(f"Error processing message: {str(e)}")
            return {
                "message": "I apologize, but I encountered an error. Let me try a different approach.",
                "agent_type": self.agent_type,
                "task_type": "error_recovery"
            }

    def get_status(self) -> Dict:
        """Report readiness: agent type, model/tokenizer loaded, tool count."""
        return {
            "type": self.agent_type,
            "ready": bool(self.model and self.tokenizer),
            "tools_available": len(self.tools)
        }