File size: 5,648 Bytes
3e274d5
 
 
 
 
 
 
76166e3
 
3e274d5
 
 
 
 
 
 
 
 
 
 
 
76166e3
3e274d5
 
76166e3
3e274d5
 
 
 
 
 
 
76166e3
 
3e274d5
 
 
 
76166e3
3e274d5
 
76166e3
3e274d5
76166e3
3e274d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76166e3
3e274d5
 
 
76166e3
3e274d5
 
 
 
 
 
 
 
76166e3
3e274d5
 
 
 
76166e3
3e274d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from typing import Dict, List, Optional, Any
from crewai import Agent, Task
import logging
from utils.log_manager import LogManager
from pydantic import Field, BaseModel, ConfigDict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class BaseWellnessAgent(Agent):
    """Base CrewAI agent backed by a locally loaded Hugging Face causal LM (Mistral).

    The ``agent_type`` string selects this agent's section of the
    ``model_config`` mapping passed to ``__init__``. That section supplies the
    model id, the prompt template and the generation parameters.
    """

    # Allow arbitrary (non-pydantic) types such as the HF model and tokenizer.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Fields populated in __init__ / _initialize_model.
    log_manager: LogManager = Field(default_factory=LogManager)
    logger: Optional[logging.Logger] = Field(default=None)
    config: Dict = Field(default_factory=dict)
    model: Any = Field(default=None)
    tokenizer: Any = Field(default=None)
    agent_type: str = Field(default="base")

    def __init__(self, model_config: Dict, agent_type: str, **kwargs):
        """Create the agent, set up its logger and load the model.

        Args:
            model_config: Mapping of agent type -> model settings. The entry for
                ``agent_type`` must provide ``model_id`` and
                ``instruction_template``. It may also override the generation
                parameters read in ``_generate_response``.
            agent_type: Key into ``model_config``. It is also passed to
                ``LogManager.get_agent_logger`` to obtain this agent's logger.
            **kwargs: Forwarded to ``crewai.Agent``. ``role``, ``goal``,
                ``backstory``, ``verbose``, ``allow_delegation`` and ``tools``
                receive wellness-specific defaults when omitted.

        Raises:
            Exception: Re-raised from ``_initialize_model`` if loading fails.
        """
        # Pop (rather than get) the defaulted keys so each one reaches
        # Agent.__init__ exactly once. Combining .get() with **kwargs raised
        # "got multiple values for keyword argument" whenever a caller
        # supplied e.g. role=...
        agent_kwargs = {
            "role": kwargs.pop("role", "Wellness Support Agent"),
            "goal": kwargs.pop("goal", "Support mental wellness"),
            "backstory": kwargs.pop(
                "backstory", "I am an AI agent specialized in mental health support."
            ),
            "verbose": kwargs.pop("verbose", True),
            "allow_delegation": kwargs.pop("allow_delegation", False),
            "tools": kwargs.pop("tools", []),
        }
        agent_kwargs.update(kwargs)
        super().__init__(**agent_kwargs)

        # Initialize logging and configuration
        self.config = model_config
        self.agent_type = agent_type
        self.logger = self.log_manager.get_agent_logger(agent_type)

        # Initialize Mistral model
        self._initialize_model()

        self.logger.info(f"{agent_type.capitalize()} Agent initialized")

    def _initialize_model(self):
        """Load the tokenizer and model configured for ``self.agent_type``.

        Reads ``self.config[self.agent_type]``:
            - ``model_id`` (required): passed to ``from_pretrained``.
            - ``load_in_4bit`` (optional, default True): whether to load the
              model with bitsandbytes 4-bit quantization.

        Raises:
            Exception: Any loading error is logged and then re-raised.
        """
        try:
            # Local import: only needed here, and kept inside the try so a
            # missing symbol is logged like any other load failure.
            from transformers import BitsAndBytesConfig

            model_config = self.config[self.agent_type]
            model_id = model_config["model_id"]
            self.logger.info(f"Initializing Mistral model: {model_id}")

            # Initialize tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)

            # Passing load_in_4bit directly to from_pretrained is deprecated in
            # transformers. The supported route is an explicit
            # BitsAndBytesConfig carrying the same flag.
            quantization_config = None
            if model_config.get("load_in_4bit", True):
                quantization_config = BitsAndBytesConfig(load_in_4bit=True)

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                device_map="auto",
                quantization_config=quantization_config,
            )

            self.logger.info("Mistral model initialized successfully")

        except Exception as e:
            self.logger.error(f"Error initializing Mistral model: {str(e)}")
            raise

    def _generate_response(self, input_text: str) -> str:
        """Generate a reply to ``input_text`` with the loaded model.

        The text is inserted into the configured ``instruction_template`` via
        ``{input}``. Generation parameters are read from
        ``self.config[self.agent_type]``:
            - ``max_new_tokens``: caps generated tokens only. If absent,
              ``max_length`` (default 4096) caps prompt + generated tokens.
            - ``temperature``: default 0.7.
            - ``top_p``: default 0.95.
            - ``repetition_penalty``: default 1.1.

        Returns:
            The decoded continuation, with the prompt excluded and surrounding
            whitespace stripped. On any error, the error is logged and a fixed
            apology string is returned instead (deliberate best-effort).
        """
        try:
            agent_config = self.config[self.agent_type]

            # Prepare input with instruction template
            prompt = agent_config["instruction_template"].format(input=input_text)

            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            prompt_length = inputs["input_ids"].shape[-1]

            # max_length includes the prompt, so a long prompt can leave no room
            # for output. Prefer max_new_tokens when the config provides it.
            if "max_new_tokens" in agent_config:
                length_kwargs = {"max_new_tokens": agent_config["max_new_tokens"]}
            else:
                length_kwargs = {"max_length": agent_config.get("max_length", 4096)}

            # Generate response
            outputs = self.model.generate(
                **inputs,
                **length_kwargs,
                temperature=agent_config.get("temperature", 0.7),
                top_p=agent_config.get("top_p", 0.95),
                repetition_penalty=agent_config.get("repetition_penalty", 1.1),
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # generate() returns prompt + continuation for this model class.
            # Slicing off the prompt tokens is exact, whereas
            # str.replace(prompt, "") on the decoded text misses whenever the
            # decode round-trip does not reproduce the prompt verbatim.
            new_tokens = outputs[0][prompt_length:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}")
            return "I apologize, but I encountered an error generating a response."

    def execute_task(
        self,
        task: Task,
        context: Optional[str] = None,
        tools: Optional[List[Any]] = None,
    ) -> str:
        """Execute a CrewAI task by sending its description to the model.

        The signature mirrors ``crewai.Agent.execute_task`` so that CrewAI can
        call this override with ``context=``/``tools=`` keywords.

        Args:
            task: The task; its ``description`` becomes the message.
            context: Optional context (e.g. previous task output). When
                non-empty, it is appended to the description.
            tools: Accepted for interface compatibility; not used here.

        Returns:
            The generated reply text, or a fixed apology string on error.
        """
        self.logger.info(f"Executing task: {task.description}")

        message = task.description
        if context:
            message = f"{message}\n\nContext:\n{context}"

        try:
            # Process the task description as a message
            result = self.process_message(message)
            return result["message"]

        except Exception as e:
            self.logger.error(f"Error executing task: {str(e)}")
            return "I apologize, but I encountered an error processing your request."

    def process_message(self, message: str, context: Optional[Dict] = None) -> Dict:
        """Generate a reply to ``message``.

        Args:
            message: User text to respond to.
            context: Accepted for callers but currently unused.

        Returns:
            A dict with keys ``message`` (reply text), ``agent_type`` and
            ``task_type``. ``task_type`` is ``"dialogue"`` normally and
            ``"error_recovery"`` if an exception escapes generation.
        """
        self.logger.info("Processing message")
        context = context or {}

        try:
            # Generate response using Mistral
            response = self._generate_response(message)

            return {
                "message": response,
                "agent_type": self.agent_type,
                "task_type": "dialogue"
            }

        except Exception as e:
            self.logger.error(f"Error processing message: {str(e)}")
            return {
                "message": "I apologize, but I encountered an error. Let me try a different approach.",
                "agent_type": self.agent_type,
                "task_type": "error_recovery"
            }

    def get_status(self) -> Dict:
        """Return a summary of the agent's type, readiness and tool count.

        ``ready`` is True when both the model and the tokenizer have been set.
        """
        # Compare against None explicitly. Truthiness of the model and tokenizer
        # objects is not a readiness signal (the tokenizer's goes through __len__).
        return {
            "type": self.agent_type,
            "ready": self.model is not None and self.tokenizer is not None,
            "tools_available": len(self.tools or []),
        }