File size: 5,862 Bytes
4ae0b03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import os
import re
import subprocess
import sys
import textwrap

import yaml
def create_poetry_package(package_name):
    """Scaffold a new Poetry project at ``tmp/<package_name>`` with ``poetry new``.

    The command's exit status is not checked, so a failing ``poetry new`` is
    not raised from here.
    """
    target_dir = f"tmp/{package_name}"
    command = ["poetry", "new", target_dir]
    subprocess.run(command)
def install_poetry_packages(
    package_name,
    packages,
    naptha_sdk_source="git+https://github.com/NapthaAI/naptha-sdk.git#feat/single-file",
):
    """Add the Naptha SDK and every entry of ``packages`` to the project ``tmp/<package_name>``.

    Each dependency gets its own ``poetry add`` call, run with the project
    directory as working directory. Exit statuses are not checked, so one
    failing ``poetry add`` does not stop the remaining ones.

    Args:
        package_name: project directory name under ``tmp/``.
        packages: iterable of requirement specifiers passed to ``poetry add``.
        naptha_sdk_source: requirement specifier for the Naptha SDK; defaults
            to the ``feat/single-file`` branch of its GitHub repository.
    """
    project_dir = f"tmp/{package_name}"
    # Use cwd= instead of os.chdir(): the process-wide working directory is
    # never modified, so an exception part-way through cannot leave it inside
    # the project, and names containing "/" need no "../.." bookkeeping.
    for requirement in [naptha_sdk_source, *packages]:
        subprocess.run(["poetry", "add", requirement], cwd=project_dir)
def extract_packages(input_code):
    """Collect the top-level third-party packages imported by ``input_code``.

    Both ``from pkg.mod import name`` and ``import pkg.mod [as alias], other``
    lines are recognised at any indentation, so imports inside a function body
    count. Only the top-level package name is kept (``langchain.chat_models``
    -> ``langchain``), because the result is meant for ``poetry add``.

    Skipped:
      * names starting with ``naptha_sdk`` (the SDK is installed separately);
      * relative imports (``from .x import y``), which are local modules;
      * standard-library modules, when ``sys.stdlib_module_names`` exists
        (Python 3.10+); older interpreters do not filter them;
      * ``import ...`` lines whose items are not plain dotted names with an
        optional ``as alias`` (e.g. prose that happens to start with "import").

    Args:
        input_code: Python source text to scan.

    Returns:
        A set of top-level module names.
    """
    # [\w.]* (not +) so single-character module names are matched too.
    from_pattern = re.compile(r"\s*from\s+([A-Za-z_][\w.]*)\s+import\s+")
    import_pattern = re.compile(r"\s*import\s+(.+)")
    import_item = re.compile(r"\s*([A-Za-z_][\w.]*)(?:\s+as\s+[A-Za-z_]\w*)?\s*")
    stdlib = getattr(sys, "stdlib_module_names", frozenset())

    packages = set()
    for line in input_code.strip().split('\n'):
        from_match = from_pattern.match(line)
        if from_match:
            modules = [from_match.group(1)]
        else:
            import_match = import_pattern.match(line)
            if not import_match:
                continue
            # Drop a trailing comment so its commas are not read as items.
            statement = import_match.group(1).split('#', 1)[0]
            items = [import_item.fullmatch(part) for part in statement.split(',')]
            if not all(items):
                continue  # not a well-formed import statement
            modules = [item.group(1) for item in items]
        for module in modules:
            top_level = module.split('.')[0]
            if top_level.startswith('naptha_sdk') or top_level in stdlib:
                continue
            packages.add(top_level)
    return packages
def _rewrap_as_run_body(input_code, new_header):
    """Replace the first function signature in ``input_code`` with ``new_header``.

    Everything up to and including the first line whose stripped text starts
    with ``def `` or ``async def`` is discarded. The remaining lines become the
    body of the function declared at the end of ``new_header``. The body is
    dedented by its common leading whitespace and re-indented by four spaces,
    so the output is valid however deeply the source function was nested. If
    no ``def`` line is found, the whole snippet is used as the body. If the
    body is empty, ``pass`` is emitted.

    NOTE(review): only single-line signatures are handled (continuation lines
    of a multi-line signature would stay in the body), and any lines before
    the ``def`` (module-level imports, decorators) are dropped.
    """
    lines = input_code.strip().split('\n')
    def_index = next(
        (i for i, line in enumerate(lines)
         if line.strip().startswith(('def ', 'async def'))),
        None,
    )
    body_lines = lines if def_index is None else lines[def_index + 1:]
    body = textwrap.dedent('\n'.join(body_lines))
    if not body.strip():
        body = 'pass'  # keep the generated function syntactically valid
    return new_header + '\n' + textwrap.indent(body, '    ')


def transform_code_as(input_code):
    """Turn a single function's source into a synchronous naptha ``run`` entrypoint.

    The result starts with a ``naptha_sdk.utils.get_logger`` import and a
    module ``logger``, followed by
    ``def run(inputs, worker_nodes=None, orchestrator_node=None, flow_run=None, cfg=None)``
    whose body is the body of the first function found in ``input_code``.
    """
    new_header = (
        "from naptha_sdk.utils import get_logger\n"
        "logger = get_logger(__name__)\n"
        "def run(inputs, worker_nodes = None, orchestrator_node = None, "
        "flow_run = None, cfg: dict = None):"
    )
    return _rewrap_as_run_body(input_code, new_header)


def transform_code_mas(input_code):
    """Turn a single function's source into an ``async`` naptha ``run`` entrypoint.

    Same as ``transform_code_as`` except the header also imports
    ``naptha_sdk.agent_service.AgentService`` and declares ``async def run``.
    """
    new_header = (
        "from naptha_sdk.utils import get_logger\n"
        "from naptha_sdk.agent_service import AgentService\n"
        "logger = get_logger(__name__)\n"
        "async def run(inputs, worker_nodes = None, orchestrator_node = None, "
        "flow_run = None, cfg: dict = None):"
    )
    return _rewrap_as_run_body(input_code, new_header)
def generate_component_yaml(module_name, user_id):
    """Write the component manifest ``tmp/<module_name>/<module_name>/component.yaml``.

    ``module_name`` is used as the component's name, type and description;
    ``user_id`` is recorded as the author. Everything else is a fixed default:
    version 0.1.0, MIT license, an Ollama ``ollama/phi`` model served at
    ``http://localhost:11434``, node-local inputs/outputs that are not saved,
    and ``run.py`` as the package entrypoint. An existing file is overwritten.
    """
    ollama_settings = {
        'model': 'ollama/phi',
        'max_tokens': 1000,
        'temperature': 0,
        'api_base': 'http://localhost:11434',
    }
    io_defaults = {'save': False, 'location': 'node'}

    # yaml.dump sorts keys by default, so the order here does not affect output.
    component = {
        'name': module_name,
        'type': module_name,
        'description': module_name,
        'author': user_id,
        'version': '0.1.0',
        'license': 'MIT',
        'models': {'default_model_provider': 'ollama', 'ollama': ollama_settings},
        'inputs': {'system_message': 'You are a helpful AI assistant.', **io_defaults},
        'outputs': {'filename': 'output.txt', **io_defaults},
        'implementation': {'package': {'entrypoint': 'run.py'}},
    }

    manifest_path = f'tmp/{module_name}/{module_name}/component.yaml'
    with open(manifest_path, 'w') as manifest_file:
        yaml.dump(component, manifest_file, default_flow_style=False)
def generate_schema(module_name):
    """Write a minimal pydantic input schema to ``tmp/<module_name>/<module_name>/schemas.py``.

    The generated module defines ``InputSchema(BaseModel)`` with a single
    required ``prompt: str`` field. The target directory must already exist;
    an existing file is overwritten.
    """
    # Explicit four-space indent: the class body must be indented, or the
    # generated schemas.py is not valid Python.
    schema_code = (
        "from pydantic import BaseModel\n"
        "class InputSchema(BaseModel):\n"
        "    prompt: str\n"
    )
    with open(f'tmp/{module_name}/{module_name}/schemas.py', 'w', encoding='utf-8') as file:
        file.write(schema_code)
def check_hf_repo_exists(hf_api, repo_id: str) -> bool:
    """Report whether ``hf_api.repo_info(repo_id)`` completes without raising.

    Returns True if the lookup returns normally and False if it raises.

    NOTE(review): *any* exception counts as "does not exist", so a network or
    authentication failure cannot be told apart from a missing repository.
    """
    try:
        hf_api.repo_info(repo_id)
    except Exception:
        return False
    return True
def publish_hf_package(hf_api, module_name, repo_id, code, user_id):
    """Write the module's files and publish ``tmp/<module_name>`` to the Hugging Face Hub.

    Steps:
      1. write ``code`` to ``tmp/<module_name>/<module_name>/run.py``, then
         generate ``schemas.py`` and ``component.yaml`` next to it;
      2. create the model repository ``<user_id>/<repo_id>`` if the existence
         check reports it missing;
      3. upload the whole ``tmp/<module_name>`` folder to it;
      4. (re)create the ``v0.1`` tag, deleting an existing one first.

    Args:
        hf_api: presumably a ``huggingface_hub.HfApi`` instance — verify.
        module_name: package directory name under ``tmp/``.
        repo_id: repository name without namespace.
        code: source text for ``run.py``.
        user_id: Hub namespace that owns the repository; also recorded as the
            component author.
    """
    with open(f'tmp/{module_name}/{module_name}/run.py', 'w', encoding='utf-8') as file:
        file.write(code)
    generate_schema(module_name)
    generate_component_yaml(module_name, user_id)

    repo = f"{user_id}/{repo_id}"
    repo_type = "model"
    if not check_hf_repo_exists(hf_api, repo):
        # Create under the same fully-qualified id that was checked and is
        # uploaded to; a bare repo_id would be created in the token owner's
        # namespace, which differs from user_id for organisations.
        hf_api.create_repo(repo_id=repo, repo_type=repo_type)
    hf_api.upload_folder(
        folder_path=f'tmp/{module_name}',
        repo_id=repo,
        repo_type=repo_type,
    )

    desired_tag = "v0.1"
    refs = hf_api.list_repo_refs(repo, repo_type=repo_type)
    existing_tags = {tag_info.name for tag_info in (refs.tags or [])}
    if desired_tag in existing_tags:
        # Remove the old tag so the re-created one refers to the latest upload.
        hf_api.delete_tag(repo, tag=desired_tag, repo_type=repo_type)
    hf_api.create_tag(repo, tag=desired_tag, repo_type=repo_type)