# Source: Hugging Face Hub upload by richardblythman
# Commit message: "Upload folder using huggingface_hub"
# Commit: 4ae0b03 (verified)
import os
import re
import subprocess
import yaml
def create_poetry_package(package_name):
    """Scaffold a new Poetry project under ``tmp/<package_name>``.

    Runs ``poetry new`` as a subprocess. The exit status is not inspected,
    so a failing ``poetry`` invocation does not raise here.

    Args:
        package_name: Name of the package; also the directory name created
            below ``tmp/``.
    """
    target_dir = f"tmp/{package_name}"
    command = ["poetry", "new", target_dir]
    subprocess.run(command)
def install_poetry_packages(package_name, packages):
    """Add the Naptha SDK and the given packages to a Poetry project.

    The project is expected at ``tmp/<package_name>`` relative to the current
    working directory. Each ``poetry add`` runs with that directory as its
    working directory via ``cwd=``; the working directory of *this* process
    is never changed, so it cannot be left pointing at the package directory
    when a command raises part-way through.

    Args:
        package_name: Name of the Poetry project directory below ``tmp/``.
        packages: Iterable of requirement strings passed to ``poetry add``
            one at a time.

    Raises:
        FileNotFoundError: If ``tmp/<package_name>`` does not exist, or if
            the ``poetry`` executable cannot be found.
    """
    # Resolve once so every subprocess gets the same absolute directory.
    project_dir = os.path.abspath(f"tmp/{package_name}")
    if not os.path.isdir(project_dir):
        raise FileNotFoundError(f"Poetry project directory not found: {project_dir}")

    # The SDK is always installed first, from the single-file feature branch.
    requirements = ["git+https://github.com/NapthaAI/naptha-sdk.git#feat/single-file"]
    requirements.extend(packages)

    for requirement in requirements:
        # Exit codes are deliberately not checked: one package failing to
        # install does not stop the remaining ones from being attempted.
        subprocess.run(["poetry", "add", requirement], cwd=project_dir)
def extract_packages(input_code):
    """Collect the module paths named in ``from <module> import ...`` lines.

    Only the ``from ... import ...`` form is recognised; plain ``import x``
    statements and relative imports (``from . import x``) are ignored.
    Modules whose dotted path starts with ``naptha_sdk`` are skipped. The
    full dotted path is returned as written (e.g. ``"os.path"``), not just
    the top-level package.

    Args:
        input_code: Python source code as a single string.

    Returns:
        A set of dotted module paths found in the source.
    """
    # ``[\w\.]*`` (not ``+``) so that single-character module names such as
    # ``from x import y`` are matched as well.
    import_pattern = re.compile(r"\s*from\s+([a-zA-Z_][\w\.]*)\s+import\s+(.*)")
    packages = set()
    for line in input_code.strip().split('\n'):
        match = import_pattern.match(line)
        if match is None:
            continue
        package_name = match.group(1).strip()
        if not package_name.startswith('naptha_sdk'):
            packages.add(package_name)
    return packages
def _strip_one_indent(line, width=4):
    """Remove at most *width* leading spaces from *line*, never other characters."""
    leading = len(line) - len(line.lstrip(' '))
    return line[min(leading, width):]


def _swap_entrypoint(input_code, new_header):
    """Replace everything up to the first ``def`` line of *input_code* with *new_header*."""
    lines = input_code.strip().split('\n')

    # Index of the first line that opens a (possibly async) function.
    def_line_index = next(
        (
            index
            for index, line in enumerate(lines)
            if line.strip().startswith(('def ', 'async def'))
        ),
        None,
    )

    # Drop everything up to and including the signature line. When there is
    # no signature at all, keep every line instead of silently losing the
    # first one.
    if def_line_index is not None:
        lines = lines[def_line_index + 1:]

    # Dedent the remaining body by one level (up to four spaces per line).
    transformed_lines = [_strip_one_indent(line) for line in lines]
    return new_header + '\n' + '\n'.join(transformed_lines)


def transform_code_as(input_code):
    """Rewrite a function's source so its body sits under a synchronous ``run`` entry point.

    Every line up to and including the first line starting with ``def `` or
    ``async def`` is discarded and replaced by a header that imports and
    creates a logger and declares ``def run(...)``. The remaining lines are
    dedented by up to four spaces each. Only the first line of the original
    signature is removed, so a signature spanning several lines is not fully
    stripped.

    Args:
        input_code: Source code containing the function to convert.

    Returns:
        The new header followed by the dedented body, as one string.
    """
    new_header = '''from naptha_sdk.utils import get_logger
logger = get_logger(__name__)
def run(inputs, worker_nodes = None, orchestrator_node = None, flow_run = None, cfg: dict = None):'''
    return _swap_entrypoint(input_code, new_header)


def transform_code_mas(input_code):
    """Rewrite a function's source so its body sits under an ``async run`` entry point.

    Same transformation as ``transform_code_as``, but the header additionally
    imports ``AgentService`` and declares ``async def run(...)``.

    Args:
        input_code: Source code containing the function to convert.

    Returns:
        The new header followed by the dedented body, as one string.
    """
    new_header = '''from naptha_sdk.utils import get_logger
from naptha_sdk.agent_service import AgentService
logger = get_logger(__name__)
async def run(inputs, worker_nodes = None, orchestrator_node = None, flow_run = None, cfg: dict = None):'''
    return _swap_entrypoint(input_code, new_header)
def generate_component_yaml(module_name, user_id):
    """Write the component descriptor for a module to ``component.yaml``.

    The file is created at ``tmp/<module_name>/<module_name>/component.yaml``
    and is overwritten if it already exists. The module name is used for the
    ``name``, ``type`` and ``description`` fields; all other values are fixed
    defaults.

    Args:
        module_name: Name of the module; also selects the output directory.
        user_id: Value stored in the ``author`` field.
    """
    # Settings for the only configured model provider.
    ollama_settings = {
        'model': 'ollama/phi',
        'max_tokens': 1000,
        'temperature': 0,
        'api_base': 'http://localhost:11434'
    }
    model_section = {
        'default_model_provider': 'ollama',
        'ollama': ollama_settings
    }

    input_section = {
        'system_message': 'You are a helpful AI assistant.',
        'save': False,
        'location': 'node'
    }
    output_section = {
        'filename': 'output.txt',
        'save': False,
        'location': 'node'
    }

    component = {
        'name': module_name,
        'type': module_name,
        'author': user_id,
        'version': '0.1.0',
        'description': module_name,
        'license': 'MIT',
        'models': model_section,
        'inputs': input_section,
        'outputs': output_section,
        'implementation': {'package': {'entrypoint': 'run.py'}}
    }

    destination = f'tmp/{module_name}/{module_name}/component.yaml'
    with open(destination, 'w') as file:
        yaml.dump(component, file, default_flow_style=False)
def generate_schema(module_name):
    """Write a minimal pydantic input schema for a module.

    Creates (or overwrites) ``tmp/<module_name>/<module_name>/schemas.py``
    with an ``InputSchema`` model that has a single ``prompt: str`` field.

    Args:
        module_name: Name of the module; selects the output directory, which
            must already exist.
    """
    # The field must be indented inside the class body; at column 0 the
    # generated schemas.py would be a syntax error.
    schema_code = '''from pydantic import BaseModel
class InputSchema(BaseModel):
    prompt: str
'''
    with open(f'tmp/{module_name}/{module_name}/schemas.py', 'w') as file:
        file.write(schema_code)
def check_hf_repo_exists(hf_api, repo_id: str) -> bool:
    """Report whether ``hf_api.repo_info(repo_id)`` succeeds.

    Any exception raised by the lookup is treated as "does not exist", so
    failures unrelated to a missing repository also yield ``False``.

    Args:
        hf_api: Client object exposing ``repo_info(repo_id)``.
        repo_id: Repository identifier to look up.

    Returns:
        ``True`` if the lookup completed without raising, ``False`` otherwise.
    """
    found = True
    try:
        # Only the success or failure of the call matters; its result is unused.
        hf_api.repo_info(repo_id)
    except Exception:
        found = False
    return found
def publish_hf_package(hf_api, module_name, repo_id, code, user_id):
    """Write a module's files under ``tmp/`` and upload them to a Hugging Face repo.

    Steps:
      1. Write *code* to ``tmp/<module_name>/<module_name>/run.py`` and
         generate ``schemas.py`` and ``component.yaml`` next to it.
      2. Create the repository ``<user_id>/<repo_id>`` if looking it up fails.
      3. Upload the whole ``tmp/<module_name>`` folder to that repository.
      4. (Re-)create the ``v0.1`` tag, deleting an existing one first.

    Args:
        hf_api: Client object exposing ``repo_info``, ``create_repo``,
            ``upload_folder``, ``list_repo_refs``, ``create_tag`` and
            ``delete_tag`` (presumably a ``huggingface_hub.HfApi`` — confirm
            against the caller).
        module_name: Name of the local package directory below ``tmp/``.
        repo_id: Repository name without the namespace.
        code: Source code written to ``run.py``.
        user_id: Namespace of the repository; also stored as the component
            author.
    """
    with open(f'tmp/{module_name}/{module_name}/run.py', 'w') as file:
        file.write(code)
    generate_schema(module_name)
    generate_component_yaml(module_name, user_id)

    repo = f"{user_id}/{repo_id}"
    if not check_hf_repo_exists(hf_api, repo):
        # Create under the same fully qualified id that is checked above and
        # uploaded to below, so the repo cannot land in another namespace.
        hf_api.create_repo(repo_id=repo)

    hf_api.upload_folder(
        folder_path=f'tmp/{module_name}',
        repo_id=repo,
        repo_type="model",
    )

    desired_tag = "v0.1"
    refs = hf_api.list_repo_refs(repo)
    existing_tags = {ref.name for ref in refs.tags} if refs.tags else set()
    # An existing tag is removed first so it can be created again after the
    # upload; create_tag is called without an explicit revision.
    if desired_tag in existing_tags:
        hf_api.delete_tag(repo, tag=desired_tag, repo_type="model")
    hf_api.create_tag(repo, repo_type="model", tag=desired_tag)