from typing import Dict, Any, List, Union, Optional
import os


class Memory:
    """In-memory record of a user query, its context files, and the actions taken for it."""

    # TODO: support richer data sources (e.g. loading image/PDF/text content),
    # not just file names paired with descriptions.

    def __init__(self):
        self.query: Optional[str] = None
        # Each entry is a dict with keys 'file_name' and 'description'.
        self.files: List[Dict[str, str]] = []
        # Keyed by "Action Step {step_count}"; see add_action.
        self.actions: Dict[str, Dict[str, Any]] = {}
        self._init_file_types()

    def set_query(self, query: str) -> None:
        """Store the query string.

        Raises:
            TypeError: if ``query`` is not a ``str``.
        """
        if not isinstance(query, str):
            raise TypeError("Query must be a string")
        self.query = query

    def _init_file_types(self):
        """Build the extension tables used to generate default file descriptions."""
        # Maps a file category to the lowercase extensions (with leading dot) in it.
        self.file_types = {
            'image': ['.jpg', '.jpeg', '.png', '.gif', '.bmp'],
            'text': ['.txt', '.md'],
            'document': ['.pdf', '.doc', '.docx'],
            'code': ['.py', '.js', '.java', '.cpp', '.h'],
            'data': ['.json', '.csv', '.xml'],
            'spreadsheet': ['.xlsx', '.xls'],
            'presentation': ['.ppt', '.pptx'],
        }
        # Per-category description templates; {ext} is the extension without its dot.
        self.file_type_descriptions = {
            'image': "An image file ({ext} format) provided as context for the query",
            'text': "A text file ({ext} format) containing additional information related to the query",
            'document': "A document ({ext} format) with content relevant to the query",
            'code': "A source code file ({ext} format) potentially related to the query",
            'data': "A data file ({ext} format) containing structured data pertinent to the query",
            'spreadsheet': "A spreadsheet file ({ext} format) with tabular data relevant to the query",
            'presentation': "A presentation file ({ext} format) with slides related to the query",
        }

    def _get_default_description(self, file_name: str) -> str:
        """Return a generic description of ``file_name`` based on its extension.

        The extension is matched case-insensitively against ``self.file_types``.
        Unknown extensions get a generic message that names the extension.
        Files without an extension get a message saying so.
        """
        _, ext = os.path.splitext(file_name)
        ext = ext.lower()
        if not ext:
            # Without this guard, the fallback below would render "A file with  extension".
            return "A file without an extension, provided as context for the query"
        for file_type, extensions in self.file_types.items():
            if ext in extensions:
                return self.file_type_descriptions[file_type].format(ext=ext[1:])
        return f"A file with {ext[1:]} extension, provided as context for the query"

    def add_file(self, file_name: Union[str, List[str]], description: Union[str, List[str], None] = None) -> None:
        """Record one or more context files with their descriptions.

        Args:
            file_name: A single file name, or an iterable of file names.
            description: A single description, an iterable of descriptions
                (one per file), or ``None`` to generate defaults from each
                file's extension.

        Raises:
            ValueError: if the number of files and descriptions differ.
        """
        # Materialize into lists so tuples and generators work and can be measured.
        file_names = [file_name] if isinstance(file_name, str) else list(file_name)
        if description is None:
            descriptions = [self._get_default_description(fname) for fname in file_names]
        elif isinstance(description, str):
            descriptions = [description]
        else:
            descriptions = list(description)

        if len(file_names) != len(descriptions):
            raise ValueError("The number of files and descriptions must match.")

        for fname, desc in zip(file_names, descriptions):
            self.files.append({
                'file_name': fname,
                'description': desc,
            })

    def add_action(self, step_count: int, tool_name: str, sub_goal: str, command: str, result: Any) -> None:
        """Record the action taken at ``step_count``.

        Any action already stored under the same step number is replaced.
        """
        action = {
            'tool_name': tool_name,
            'sub_goal': sub_goal,
            'command': command,
            'result': result,
        }
        step_name = f"Action Step {step_count}"
        self.actions[step_name] = action

    def get_query(self) -> Optional[str]:
        """Return the stored query, or ``None`` if it has not been set."""
        return self.query

    def get_files(self) -> List[Dict[str, str]]:
        """Return the internal list of recorded files (not a copy)."""
        return self.files

    def get_actions(self) -> Dict[str, Dict[str, Any]]:
        """Return the internal dict of recorded actions (not a copy)."""
        return self.actions