Spaces:
No application file
No application file
File size: 5,499 Bytes
3883c60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import json
import os.path
import shlex
import subprocess
from enum import Enum
from setup_tools.os import is_windows
# JSON file holding each extension's saved enabled flag, keyed by extension name.
extension_states = os.path.join('data', 'extensions.json')
# Directory whose sub-folders are scanned for extensions.
ext_folder = 'extensions'
def git_ready():
    """Return True if a working ``git`` executable can be run.

    Runs ``git --version`` and checks its exit code. A missing or
    non-executable ``git`` makes ``subprocess.run`` raise rather than return
    a non-zero code, so that case is caught and also reported as False.
    """
    cmd = 'git --version'
    # Windows takes a command-line string; elsewhere argv must be a list.
    cmd = cmd if is_windows() else shlex.split(cmd)
    try:
        result = subprocess.run(cmd, capture_output=True)
    except OSError:
        # FileNotFoundError / PermissionError: git is not available.
        return False
    return result.returncode == 0
class UpdateStatus(Enum):
    """Result reported by ``Extension.check_updates``."""

    no_git = -1     # a git command failed
    unmanaged = 0   # the extension folder has no .git directory
    updated = 1     # git reports the branch is up to date
    outdated = 2    # git suggests pulling, or the status was inconclusive
class Extension:
    """One extension: a sub-folder of ``ext_folder`` containing ``extension.json``.

    Besides the required metadata file, an extension may provide ``main.py``
    (imported by ``activate``), ``requirements.py`` (exposing
    ``requirements()``), ``style.py`` and ``scripts/script.js``. A ``.git``
    folder makes it updatable through ``check_updates`` / ``update``.
    """

    def __init__(self, ext_name, load_states):
        """Resolve the extension's file paths and load its metadata.

        :param ext_name: Folder name of the extension inside ``ext_folder``.
        :param load_states: Mapping of extension name -> saved enabled flag
            (as returned by ``get_load_states``); names absent from it start
            enabled.
        :raises FileNotFoundError: If the folder has no ``extension.json``.
        """
        # Extensions without a saved state (e.g. newly installed) default to enabled.
        self.enabled = load_states.get(ext_name, True)
        self.extname = ext_name
        self.path = os.path.join(ext_folder, ext_name)
        self.main_file = os.path.join(self.path, 'main.py')
        self.req_file = os.path.join(self.path, 'requirements.py')  # Optional
        self.style_file = os.path.join(self.path, 'style.py')
        self.js_file = os.path.join(self.path, 'scripts', 'script.js')
        self.git_dir = os.path.join(self.path, '.git')
        # Never assigned in this module; presumably set by UI code — TODO confirm.
        self.update_el = None

        extinfo = os.path.join(self.path, 'extension.json')
        if not os.path.isfile(extinfo):
            raise FileNotFoundError(f'No extension.json file for {ext_name} extension.')
        with open(extinfo, 'r', encoding='utf8') as info_file:
            self.info = json.load(info_file)
        # Guarantee these keys exist so callers can read them unconditionally.
        for key in ('name', 'description', 'author'):
            self.info.setdefault(key, 'Not provided')
        self.info.setdefault('tags', [])

    @staticmethod
    def _import_file(file_path):
        """Import the relative Python file *file_path* as a module and return it.

        The path is converted to a dotted module name, e.g.
        ``extensions/foo/main.py`` -> ``extensions.foo.main``.
        """
        module_name = os.path.splitext(file_path)[0].replace(os.path.sep, '.')
        # A non-empty fromlist makes __import__ return the leaf module,
        # not the top-level package.
        return __import__(module_name, fromlist=[''])

    @staticmethod
    def _git_command(command):
        """Return *command* in the form ``subprocess.run`` needs on this OS."""
        # Windows accepts a command-line string; elsewhere argv must be a list.
        return command if is_windows() else shlex.split(command)

    def activate(self):
        """Import ``main.py`` if the extension is enabled and provides one."""
        if self.enabled and os.path.isfile(self.main_file):
            self._import_file(self.main_file)

    def get_style_rules(self):
        """Import ``style.py`` if the extension is enabled and provides one.

        Nothing is returned. Presumably the module applies its rules as an
        import side effect — TODO confirm.
        """
        if self.enabled and os.path.isfile(self.style_file):
            self._import_file(self.style_file)

    def get_requirements(self):
        """Return ``requirements()`` from ``requirements.py``.

        Returns ``[]`` when the extension is disabled or has no such file.
        """
        if self.enabled and os.path.isfile(self.req_file):
            return self._import_file(self.req_file).requirements()
        return []

    def get_javascript(self) -> str | bool:
        """Return the path of ``scripts/script.js``, or False if disabled or absent."""
        if self.enabled and os.path.isfile(self.js_file):
            return self.js_file
        return False

    def set_enabled(self, new):
        """Set the enabled flag and persist every extension's state to disk.

        :returns: ``gradio.update(value=new)`` when gradio is usable,
            otherwise *new* itself.
        """
        self.enabled = new
        set_load_states()
        try:
            import gradio
            return gradio.update(value=new)
        except (ImportError, AttributeError):
            # gradio is optional here; fall back to the plain value instead of
            # swallowing every exception (including KeyboardInterrupt).
            return new

    def check_updates(self) -> UpdateStatus:
        """Fetch from the remote and report whether the checkout is behind.

        :returns: ``unmanaged`` without a ``.git`` folder; ``no_git`` if git
            cannot be run or either git command fails; ``updated`` when git
            says the branch is up to date; otherwise ``outdated``.
        """
        if not os.path.isdir(self.git_dir):
            return UpdateStatus.unmanaged
        search_string = 'git pull'  # Included in message from git if not up to date
        neg_search_string = 'Your branch is up to date'
        try:
            fetch = subprocess.run(self._git_command('git fetch'),
                                   capture_output=True, cwd=self.path)
            if fetch.returncode != 0:
                return UpdateStatus.no_git
            status = subprocess.run(self._git_command('git status -uno'),
                                    capture_output=True, cwd=self.path)
        except OSError:
            # git executable missing or not runnable.
            return UpdateStatus.no_git
        # Check the status command's own exit code (previously fetch's was re-checked).
        if status.returncode != 0:
            return UpdateStatus.no_git
        out_string = status.stdout.decode(errors='replace')
        # NOTE(review): these phrases match git's English output only; a
        # localized git would fall through to `outdated`.
        if search_string in out_string:
            return UpdateStatus.outdated
        if neg_search_string in out_string:
            return UpdateStatus.updated
        return UpdateStatus.outdated

    def update(self):
        """Run ``git pull`` in the extension folder; no-op if it is not a git checkout."""
        if not os.path.isdir(self.git_dir):
            return
        try:
            output = subprocess.run(self._git_command('git pull'),
                                    capture_output=True, cwd=self.path)
            failed = output.returncode != 0
        except OSError:
            # git executable missing or not runnable.
            failed = True
        if failed:
            print(f'Something went wrong during git pull for {self.extname}')
def get_valid_extensions(folder=None):
    """Return the sorted names of extension folders inside *folder*.

    An entry counts as an extension when it is a directory containing an
    ``extension.json`` file.

    :param folder: Directory to scan; defaults to ``ext_folder``.
    :returns: Folder names sorted alphabetically. ``os.listdir`` order is
        arbitrary, and ``init_extensions`` registers extensions in the order
        returned here. Returns ``[]`` if *folder* does not exist.
    """
    if folder is None:
        folder = ext_folder
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        # No extensions folder means no extensions, not a startup crash.
        return []
    return sorted(
        name for name in entries
        if os.path.isdir(os.path.join(folder, name))
        and os.path.isfile(os.path.join(folder, name, 'extension.json'))
    )
# Every discovered extension, keyed by folder name; populated by init_extensions().
states: dict[str, Extension] = dict()
def set_load_states():
    """Persist every loaded extension's enabled flag to ``extension_states``.

    Writes a JSON object mapping extension name -> enabled flag, creating
    the parent directory first if it does not exist.
    """
    enabled_by_name = {name: ext.enabled for name, ext in states.items()}
    parent = os.path.dirname(extension_states)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # `with` guarantees the file is flushed and closed, even on error.
    with open(extension_states, 'w', encoding='utf8') as state_file:
        json.dump(enabled_by_name, state_file)
def get_load_states():
    """Return the saved name -> enabled-flag mapping, or ``{}`` if none was saved."""
    if not os.path.isfile(extension_states):
        return {}
    # `with` closes the handle instead of leaving it to the garbage collector.
    with open(extension_states, 'r', encoding='utf8') as state_file:
        return json.load(state_file)
# Callback names that init_extensions() registers before discovering extensions.
register_callbacks = [
    'webui.init',
    'webui.settings',
    'webui.tabs',
    'webui.tabs.utils',
    'webui.tts.list',
]
def init_extensions():
    """Register the default callback names, then discover every extension.

    Each valid extension folder, enabled or not, is wrapped in an
    ``Extension`` and stored in the module-level ``states`` dict under its
    folder name. Nothing is activated here.
    """
    from webui.extensionlib.callbacks import register_new

    for callback_name in register_callbacks:
        register_new(callback_name)

    saved_states = get_load_states()
    found = get_valid_extensions()
    print('Found extensions: ' + ', '.join(found))
    for name in found:
        states[name] = Extension(name, saved_states)
def get_scripts() -> list[str]:
    """Return the ``script.js`` paths of all extensions that provide one while enabled."""
    # get_javascript() yields False when there is no script; keep only real paths.
    return [js for ext in states.values() if (js := ext.get_javascript())]
def get_requirements():
    """Concatenate the requirement lists of all loaded extensions, in load order."""
    collected = []
    for ext in states.values():
        reqs = ext.get_requirements()
        # Extensions without requirements contribute an empty (falsy) result.
        if reqs:
            collected.extend(reqs)
    return collected
|