File size: 8,707 Bytes
20e0ec5
187aba8
 
 
 
 
 
692c512
187aba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb918bc
187aba8
 
 
 
 
 
eb918bc
187aba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dcb660
187aba8
 
 
 
 
 
3dcb660
187aba8
 
 
3dcb660
187aba8
 
3dcb660
187aba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4516a
187aba8
 
7d4516a
187aba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4516a
187aba8
 
7d4516a
187aba8
 
 
 
 
 
78e9275
187aba8
 
 
 
 
eb918bc
7d4516a
187aba8
 
 
 
 
 
 
 
 
 
 
7d4516a
187aba8
 
 
60045aa
187aba8
 
 
60045aa
 
 
 
 
 
187aba8
 
 
 
 
 
 
 
 
60045aa
187aba8
 
 
 
 
 
 
 
60045aa
187aba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60045aa
187aba8
 
 
3c3bf42
ecbc10f
187aba8
 
 
 
60045aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import os
import logging
import asyncio
import yaml
from typing import Dict, List, Any, Tuple, Optional
from abc import ABC, abstractmethod
import gradio as gr
from langchain_community.llms import HuggingFaceHub
from dotenv import load_dotenv
from langchain.agents import initialize_agent, AgentType
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

# Load environment variables (e.g. from a local .env file) into the process
# environment. This must run before load_config(), which reads
# HUGGINGFACE_API_KEY via os.getenv.
load_dotenv()

# Custom Exceptions
class CodeFusionError(Exception):
    """Root of the CodeFusion exception hierarchy.

    Catch this type to handle any CodeFusion-specific failure at once.
    """


class AgentInitializationError(CodeFusionError):
    """Signals that an agent (or its backing LLM/executor) could not be set up."""


class ToolExecutionError(CodeFusionError):
    """Signals that a tool failed while producing its output."""

# Utility Functions
def load_config(config_path: str = 'config.yaml') -> Dict:
    """Load configuration from a YAML file, falling back to built-in defaults.

    Values found in the file override the defaults key by key. A partial or
    empty config file therefore still yields every key the rest of the module
    reads. When the ``HUGGINGFACE_API_KEY`` environment variable is set, it
    takes precedence over any ``api_key`` from the file or the defaults.

    Args:
        config_path: Path of the YAML file to read. Defaults to
            ``config.yaml`` in the current working directory.

    Returns:
        A dict containing at least ``model_name``, ``api_key``,
        ``temperature`` and ``verbose``.

    Raises:
        ValueError: If the file parses to something other than a mapping
            (e.g. a YAML list or scalar), which cannot be merged.
    """
    default_config = {
        'model_name': "google/flan-t5-xl",
        'api_key': "your_default_api_key_here",
        'temperature': 0.5,
        'verbose': True
    }

    try:
        with open(config_path, 'r', encoding='utf-8') as config_file:
            loaded = yaml.safe_load(config_file)
    except FileNotFoundError:
        print(f"Config file not found at {config_path}. Using default configuration.")
        loaded = None

    # yaml.safe_load returns None for an empty file; treat that as "no overrides".
    if loaded is not None and not isinstance(loaded, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping, "
            f"got {type(loaded).__name__}."
        )

    # Merge instead of replacing so keys missing from the file keep their defaults.
    config = {**default_config, **(loaded or {})}

    # Override with environment variables if set
    config['api_key'] = os.getenv('HUGGINGFACE_API_KEY', config['api_key'])
    return config

def setup_logging() -> logging.Logger:
    """Configure root logging to write INFO-and-above records to ``codefusion.log``.

    ``logging.basicConfig`` only takes effect if the root logger has no
    handlers yet. Calling this again after configuration is harmless.

    Returns:
        The logger named after this module.
    """
    record_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(
        filename='codefusion.log',
        format=record_format,
        level=logging.INFO,
    )
    module_logger = logging.getLogger(__name__)
    return module_logger

# Load configuration and set up logging.
# Both are module-level globals: `config` is read by the Tool and Agent
# constructors, and `logger` is used by every tool, the Agent and run().
# NOTE: load_config() runs before setup_logging(), which is why it reports a
# missing config file with print() rather than through the logger.
config = load_config()
logger = setup_logging()

# Tool Classes
class Tool(ABC):
    """Common base for the agent's tools.

    Each tool has a display ``name``, a human-readable ``description`` and its
    own Hugging Face Hub LLM client. The client is configured from the
    module-level ``config`` (model, temperature, API key). Subclasses
    implement :meth:`run`.
    """

    # NOTE(review): this class does not subclass LangChain's BaseTool, yet
    # Agent passes instances to initialize_agent(); confirm the installed
    # LangChain version accepts them, or wrap them in langchain Tool objects.

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description
        hub_settings = {
            "repo_id": config['model_name'],
            "model_kwargs": {"temperature": config['temperature']},
            "huggingfacehub_api_token": config['api_key'],
        }
        self.llm = HuggingFaceHub(**hub_settings)

    @abstractmethod
    async def run(self, arguments: Dict[str, Any]) -> Dict[str, str]:
        """Execute the tool on ``arguments`` and return its result mapping."""
        ...

class CodeGenerationTool(Tool):
    """Tool that asks the LLM to generate code in a requested language.

    Recognised ``arguments`` keys (both optional):
        language: Target language; defaults to ``"python"``.
        code_description: What the code should do; defaults to
            ``"print('Hello, World!')"``.
    """

    def __init__(self):
        super().__init__("Code Generation", "Generates code snippets in various languages.")
        self.prompt_template = PromptTemplate(
            input_variables=["language", "code_description"],
            template="Generate {language} code for: {code_description}"
        )
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt_template)

    async def run(self, arguments: Dict[str, str]) -> Dict[str, str]:
        """Generate code and return it under the ``"output"`` key.

        Raises:
            ToolExecutionError: If the LLM chain call fails. The original
                exception is attached as ``__cause__``.
        """
        language = arguments.get("language", "python")
        code_description = arguments.get("code_description", "print('Hello, World!')")
        try:
            code = await self.chain.arun(language=language, code_description=code_description)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error discarded.
            logger.exception(f"Error in CodeGenerationTool: {e}")
            raise ToolExecutionError(f"Failed to generate code: {e}") from e
        return {"output": code}

class CodeExplanationTool(Tool):
    """Tool that asks the LLM to explain a code snippet in simple terms.

    Recognised ``arguments`` keys (optional):
        code: The snippet to explain; defaults to ``"print('Hello, World!')"``.
    """

    def __init__(self):
        super().__init__("Code Explanation", "Explains code snippets in simple terms.")
        self.prompt_template = PromptTemplate(
            input_variables=["code"],
            template="Explain the following code in simple terms:\n\n{code}"
        )
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt_template)

    async def run(self, arguments: Dict[str, str]) -> Dict[str, str]:
        """Explain the given code and return the text under the ``"output"`` key.

        Raises:
            ToolExecutionError: If the LLM chain call fails. The original
                exception is attached as ``__cause__``.
        """
        code = arguments.get("code", "print('Hello, World!')")
        try:
            explanation = await self.chain.arun(code=code)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error discarded.
            logger.exception(f"Error in CodeExplanationTool: {e}")
            raise ToolExecutionError(f"Failed to explain code: {e}") from e
        return {"output": explanation}

class DebuggingTool(Tool):
    """Tool that asks the LLM to debug a code snippet given its error message.

    Recognised ``arguments`` keys (both optional, default ``""``):
        code: The snippet to debug.
        error_message: The error the snippet produced.
    """

    def __init__(self):
        super().__init__("Debugging", "Helps identify and fix issues in code snippets.")
        self.prompt_template = PromptTemplate(
            input_variables=["code", "error_message"],
            template="Debug the following code:\n\n{code}\n\nError message: {error_message}"
        )
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt_template)

    async def run(self, arguments: Dict[str, str]) -> Dict[str, str]:
        """Debug the given code and return the LLM's answer under ``"output"``.

        Raises:
            ToolExecutionError: If the LLM chain call fails. The original
                exception is attached as ``__cause__``.
        """
        code = arguments.get("code", "")
        error_message = arguments.get("error_message", "")
        try:
            debug_result = await self.chain.arun(code=code, error_message=error_message)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error discarded.
            logger.exception(f"Error in DebuggingTool: {e}")
            raise ToolExecutionError(f"Failed to debug code: {e}") from e
        return {"output": debug_result}

# Agent Class
class Agent:
    """An AI agent: a named role, a set of tools, and a LangChain agent executor.

    Attributes:
        name: Display name of the agent.
        role: Short description of the agent's role.
        tools: The tools handed to the LangChain agent.
        memory: Every ``(prompt, context)`` pair passed to :meth:`act`, in call order.
        llm: The Hugging Face Hub LLM backing the LangChain agent.
        agent: The executor returned by ``initialize_agent``.
    """

    def __init__(self, name: str, role: str, tools: List[Tool]):
        """Create the agent and its LangChain executor.

        Raises:
            AgentInitializationError: If the LLM client or the LangChain agent
                cannot be constructed. The original exception is attached as
                ``__cause__``.
        """
        self.name = name
        self.role = role
        self.tools = tools
        self.memory: List[Tuple[str, str]] = []

        try:
            self.llm = HuggingFaceHub(
                repo_id=config['model_name'],
                model_kwargs={"temperature": config['temperature']},
                huggingfacehub_api_token=config['api_key']
            )
            # NOTE(review): initialize_agent expects LangChain BaseTool
            # instances; this module's Tool classes do not subclass it, so
            # validation may reject them. Confirm against the installed
            # LangChain version, or wrap them in langchain Tool objects.
            self.agent = initialize_agent(
                tools=self.tools,
                llm=self.llm,
                # initialize_agent selects the agent type via `agent`;
                # `agent_type` is not one of its parameters.
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                verbose=config['verbose']
            )
        except Exception as e:
            logger.exception(f"Error initializing agent: {e}")
            raise AgentInitializationError(f"Failed to initialize agent: {e}") from e

    async def act(self, prompt: str, context: str, mode: str) -> str:
        """Respond to ``prompt`` according to the engagement ``mode``.

        Args:
            prompt: The user's current message.
            context: Prior conversation rendered as text (may be empty).
            mode: ``"full"`` runs the LangChain agent. ``"half"`` returns
                instructions for the user to carry out. Any other value
                returns a standing offer of help.

        Returns:
            The response text.

        Raises:
            Exception: Whatever the LangChain agent raises in ``"full"`` mode.
                It is logged, then re-raised unchanged.
        """
        self.memory.append((prompt, context))
        if mode == "full":
            # Chain.arun accepts exactly one positional input, so the history
            # and the new request are folded into a single string.
            agent_input = f"{context}\nHuman: {prompt}" if context else prompt
            try:
                return await self.agent.arun(agent_input)
            except Exception as e:
                logger.exception(f"Error during agent action: {e}")
                raise
        if mode == "half":
            return f"Please follow these instructions: {prompt}. Then, let me know what you did."
        # mode == "none" (or any unrecognised mode)
        return "I'm here if you need assistance. Just ask!"

    def __str__(self) -> str:
        return f"Agent: {self.name} (Role: {self.role})"

# Main application functions
async def run(message: str, history: List[Tuple[str, str]], mode: str = "full") -> str:
    """Answer one chat message using a freshly built CodeFusion agent.

    Args:
        message: The user's new message.
        history: Prior ``(user, assistant)`` turns, as supplied by Gradio's
            ChatInterface. ``None`` is treated as no history.
        mode: Engagement mode forwarded to ``Agent.act`` (``"full"``,
            ``"half"``, or anything else for "none"). It defaults to
            ``"full"`` so this still works when the UI supplies only
            ``(message, history)``.

    Returns:
        The agent's reply. If building the agent or running it fails, the
        error is logged and an apology message is returned instead.
    """
    # A missing assistant turn (None) is rendered as empty text, not "None".
    context = "\n".join(
        f"Human: {turn[0]}\nAI: {turn[1] if turn[1] is not None else ''}"
        for turn in (history or [])
    )
    try:
        # Build the tools and agent inside the try: their constructors can
        # raise (e.g. AgentInitializationError), which previously escaped to
        # the UI instead of producing the apology reply below.
        agent = Agent(
            name="CodeFusion",
            role="AI Coding Assistant",
            tools=[CodeGenerationTool(), CodeExplanationTool(), DebuggingTool()]
        )
        return await agent.act(message, context, mode)
    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        return "I apologize, but an error occurred while processing your request. Please try again."

async def main():
    """Build and launch the Gradio chat UI for CodeFusion.

    ``run`` takes ``(message, history, mode)``. ChatInterface passes the
    values of ``additional_inputs`` after ``(message, history)``, so a radio
    selector supplies ``mode``.
    """
    mode_selector = gr.Radio(
        choices=["full", "half", "none"],
        value="full",
        label="Mode",
    )
    # With additional_inputs, each example is [message, *additional_input_values];
    # examples populate the input box, they are not canned replies.
    examples = [
        ["What is the purpose of this AI agent?", "full"],
        ["Can you help me generate a Python function to calculate the factorial of a number?", "full"],
        ["Explain the concept of recursion in programming.", "full"],
    ]

    # NOTE: launch() blocks while the server runs; nothing in this coroutine is awaited.
    gr.ChatInterface(
        fn=run,
        title="CodeFusion: Your AI Coding Assistant",
        description="Ask me about code generation, explanation, debugging, or any other coding task! Choose a mode: full autonomy, 50% engagement, or user-referred assistance.",
        examples=examples,
        additional_inputs=[mode_selector],
        theme="default"
    ).launch()

if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == "--test":
        # run_tests() is not defined in this module; look it up so `--test`
        # fails with a clear message instead of a NameError.
        test_runner = globals().get("run_tests")
        if test_runner is None:
            sys.exit("--test was requested, but no run_tests() function is defined in this module.")
        test_runner()
    else:
        asyncio.run(main())