import json
import os
from pathlib import Path

import requests
import tqdm

from modules import shared, utils
|
|
# Kept for reference; re-enabling this requires `import torch`, which this
# module does not currently load.
'''
def get_gpu_memory_usage(rank):
    return {
        'total': round(torch.cuda.get_device_properties(rank).total_memory / (1024**3), 2),
        'max': round(torch.cuda.max_memory_allocated(rank) / (1024**3), 2),
        'reserved': round(torch.cuda.memory_reserved(rank) / (1024**3), 2),
        'allocated': round(torch.cuda.memory_allocated(rank) / (1024**3), 2)
    }
'''
|
|
def list_subfoldersByTime(directory):
    # Return 'None' plus all subfolders of `directory`, newest first.
    if not directory.endswith('/'):
        directory += '/'

    subfolders = ['None']
    name_list = os.listdir(directory)
    full_list = [os.path.join(directory, i) for i in name_list]
    time_sorted_list = sorted(full_list, key=os.path.getmtime, reverse=True)

    for entry in time_sorted_list:
        if os.path.isdir(entry):
            # Normalize Windows separators, then strip the parent prefix.
            entry_str = entry.replace('\\', '/')
            entry_str = entry_str.replace(directory, '')
            subfolders.append(entry_str)

    return subfolders
|
|
def get_available_loras_local(_sortedByTime):
    # List LoRA folders either by modification time (newest first) or
    # alphabetically via the web UI's own helper.
    model_dir = shared.args.lora_dir
    if _sortedByTime:
        subfolders = list_subfoldersByTime(model_dir)
    else:
        subfolders = utils.get_available_loras()

    return subfolders
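
# Usage sketch (hypothetical wiring; assumes shared.args.lora_dir points at
# the web UI's loras folder):
#   loras = get_available_loras_local(True)   # newest first
#   loras = get_available_loras_local(False)  # default alphabetical listing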
|
|
def split_sentences(text: str, cutoff_len: int):
    # Split raw text into sentences, tokenizing each one with the loaded
    # tokenizer. Returns a list of {'text': str, 'size': token_count} dicts.
    sentences = []
    sentence = ''
    delimiters = ['. ', '? ', '! ', '... ', '.\n', '?\n', '!\n', '...\n', '</s>', '<//>']
    abbreviations = ['Mr. ', 'Mrs. ', 'Dr. ', 'Ms. ', 'St. ', 'Prof. ', 'Jr. ', 'Ltd. ', 'Capt. ', 'Col. ', 'Gen. ', 'Ave. ', 'Blvd. ', 'Co. ', 'Corp. ', 'Dept. ', 'Est. ', 'Gov. ', 'Inc. ', 'Ph.D. ', 'Univ. ']
    errors = 0
    max_cut = cutoff_len - 1
    prev_char = ''

    for char in text:
        sentence += char

        # End a sentence on any delimiter, unless it ends with a known
        # abbreviation; the prev_char check attempts to skip initials.
        if (any(sentence.endswith(delimiter) for delimiter in delimiters) and
                not (prev_char.isupper() and len(sentence) >= 3 and sentence[-3] != ' ') and
                not any(sentence.endswith(abbreviation) for abbreviation in abbreviations)):
            tokens = shared.tokenizer.encode(sentence)

            # A single sentence longer than the cutoff is trimmed to fit.
            if len(tokens) > max_cut:
                tokens = tokens[:max_cut]
                sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True)
                errors += 1

            sentences.append({'text': sentence, 'size': len(tokens)})
            sentence = ''

        prev_char = char

    # Flush whatever is left after the last delimiter.
    if sentence:
        tokens = shared.tokenizer.encode(sentence)
        if len(tokens) > max_cut:
            tokens = tokens[:max_cut]
            sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True)
            errors += 1

        sentences.append({'text': sentence, 'size': len(tokens)})

    if errors > 0:
        print(f"Trimmed sentences beyond Cutoff Length: {errors}")

    return sentences
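
# Minimal usage sketch (hypothetical): assumes a tokenizer has already been
# loaded into shared.tokenizer by the web UI.
def _demo_split_sentences():
    sample = "Dr. Smith arrived. He sat down! Then he spoke...\nThe end."
    for chunk in split_sentences(sample, cutoff_len=256):
        # 'size' is the token count; 'Dr. ' does not end a sentence because
        # it is in the abbreviation list.
        print(chunk['size'], repr(chunk['text']))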
|
|
def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer: bool):
    # Slice text into blocks of at most cutoff_len tokens, cutting only at
    # sentence boundaries. Hard-cut strings become <//> markers that no
    # block may cross.
    EOSX_str = '<//>'  # hard-cut marker
    EOS_str = '</s>'
    print("Precise raw text slicer: ON")

    cut_string = hard_cut_string.replace('\\n', '\n')
    text = text.replace(cut_string, EOSX_str)
    sentences = split_sentences(text, cutoff_len)

    print(f"Sentences: {len(sentences)}")
    sentencelist = []
    currentSentence = ''
    totalLength = 0
    max_cut = cutoff_len - 1
    half_cut = cutoff_len // 2
    halfcut_length = 0

    edgeindex = []
    half_index = 0

    for index, item in enumerate(sentences):

        # Remember the sentence index nearest the half-cutoff point; these
        # become the starting points for the overlapping pass below.
        if halfcut_length + item['size'] < half_cut:
            halfcut_length += item['size']
            half_index = index
        else:
            edgeindex.append(half_index)
            halfcut_length = -2 * max_cut  # push the next edge roughly a block later

        if totalLength + item['size'] < max_cut and not currentSentence.endswith(EOSX_str):
            currentSentence += item['text']
            totalLength += item['size']
        else:
            # Block is full (or a hard cut was reached): flush and start anew.
            if len(currentSentence.strip()) > min_chars_cut:
                sentencelist.append(currentSentence.strip())

            currentSentence = item['text']
            totalLength = item['size']
            halfcut_length = item['size']

    # Flush the final block.
    if len(currentSentence.strip()) > min_chars_cut:
        sentencelist.append(currentSentence.strip())

    unique_blocks = len(sentencelist)
    print(f"Text Blocks: {unique_blocks}")

    if overlap:
        # Second pass: build blocks starting at the half-cut edges so that
        # adjacent blocks overlap by roughly half the cutoff length.
        for edge_idx in edgeindex:
            currentSentence = ''
            totalLength = 0

            for item in sentences[edge_idx:]:
                if totalLength + item['size'] < max_cut:
                    currentSentence += item['text']
                    totalLength += item['size']
                else:
                    # Keep the block only if it ends exactly on a hard cut or
                    # contains no hard cut at all.
                    if currentSentence.endswith(EOSX_str) and len(currentSentence.strip()) > min_chars_cut:
                        sentencelist.append(currentSentence.strip())
                    elif EOSX_str not in currentSentence and len(currentSentence.strip()) > min_chars_cut:
                        sentencelist.append(currentSentence.strip())

                    currentSentence = ''
                    totalLength = 0
                    break

        print(f"+ Overlapping blocks: {len(sentencelist) - unique_blocks}")

    num_EOS = 0
    for i in range(len(sentencelist)):
        # Turn hard-cut markers into </s> (or drop them), then collapse
        # accidental doubles.
        if eos_to_hc:
            sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
        else:
            sentencelist[i] = sentencelist[i].replace(EOSX_str, '')

        sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
        num_EOS += sentencelist[i].count(EOS_str)

    if num_EOS > 0:
        print(f"+ EOS count: {num_EOS}")

    # Drop blocks that are empty or contain only an EOS token.
    sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
    sentencelist = [item for item in sentencelist if item.strip() != ""]

    if debug_slicer:
        # Dump the final blocks to logs/sentencelist.json for inspection.
        Path('logs').mkdir(exist_ok=True)
        sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
        output_file = "logs/sentencelist.json"
        with open(output_file, 'w') as f:
            json.dump(sentencelist_dict, f, indent=2)

        print("Saved sentencelist.json in logs folder")

    return sentencelist
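
# Hypothetical usage sketch for precise_cut; the parameter values below are
# illustrative, not the extension's defaults:
def _demo_precise_cut(raw_text: str):
    return precise_cut(
        raw_text,
        overlap=True,                 # also emit half-shifted overlapping blocks
        min_chars_cut=20,             # drop blocks shorter than this many characters
        eos_to_hc=True,               # turn hard-cut markers into </s>
        cutoff_len=256,
        hard_cut_string='\\n\\n\\n',  # the literal \n\n\n as typed in the UI field
        debug_slicer=False,
    )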
|
|
def sliding_block_cut(text: str, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer: bool):
    # Slice text into maximally overlapping blocks: every sentence may start
    # a new block that extends forward until the token cutoff or a hard cut.
    EOSX_str = '<//>'  # hard-cut marker
    EOS_str = '</s>'
    print("Mega Block Overlap: ON")

    cut_string = hard_cut_string.replace('\\n', '\n')
    text = text.replace(cut_string, EOSX_str)
    sentences = split_sentences(text, cutoff_len)

    print(f"Sentences: {len(sentences)}")
    sentencelist = []

    max_cut = cutoff_len - 1

    # After a hard cut, skip ahead so blocks never straddle the cut marker.
    advancing_to = 0
    prev_block_lastsentence = ""

    for i in range(len(sentences)):
        totalLength = 0
        currentSentence = ''
        lastsentence = ""

        if i >= advancing_to:
            for k in range(i, len(sentences)):
                current_length = sentences[k]['size']

                if totalLength + current_length <= max_cut and not currentSentence.endswith(EOSX_str):
                    currentSentence += sentences[k]['text']
                    totalLength += current_length
                    lastsentence = sentences[k]['text']
                else:
                    if len(currentSentence.strip()) > min_chars_cut:
                        # Skip blocks that end on the same sentence as the
                        # previous one; they would be near-duplicates.
                        if prev_block_lastsentence != lastsentence:
                            sentencelist.append(currentSentence.strip())
                            prev_block_lastsentence = lastsentence

                    advancing_to = 0
                    if currentSentence.endswith(EOSX_str):
                        advancing_to = k

                    currentSentence = ""
                    totalLength = 0
                    break

            # Flush a tail block that ran to the end of the text.
            if currentSentence != "":
                if len(currentSentence.strip()) > min_chars_cut:
                    sentencelist.append(currentSentence.strip())

    unique_blocks = len(sentencelist)
    print(f"Text Blocks: {unique_blocks}")

    num_EOS = 0
    for i in range(len(sentencelist)):
        # Turn hard-cut markers into </s> (or drop them), then collapse
        # accidental doubles.
        if eos_to_hc:
            sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
        else:
            sentencelist[i] = sentencelist[i].replace(EOSX_str, '')

        sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
        num_EOS += sentencelist[i].count(EOS_str)

    if num_EOS > 0:
        print(f"+ EOS count: {num_EOS}")

    # Drop blocks that are empty or contain only an EOS token.
    sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
    sentencelist = [item for item in sentencelist if item.strip() != ""]

    if debug_slicer:
        # Dump the final blocks to logs/sentencelist.json for inspection.
        Path('logs').mkdir(exist_ok=True)
        sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
        output_file = "logs/sentencelist.json"
        with open(output_file, 'w') as f:
            json.dump(sentencelist_dict, f, indent=2)

        print("Saved sentencelist.json in logs folder")

    return sentencelist
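
# Hypothetical usage sketch for sliding_block_cut; values are illustrative.
# Compared with precise_cut, this yields far more blocks, since any sentence
# may begin one:
def _demo_sliding_block_cut(raw_text: str):
    return sliding_block_cut(
        raw_text,
        min_chars_cut=20,
        eos_to_hc=True,
        cutoff_len=256,
        hard_cut_string='\\n\\n\\n',
        debug_slicer=False,
    )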
|
|
def download_file_from_url(url, overwrite, output_dir_in, valid_extensions={'.txt', '.json'}):
    # Stream a text/json file from `url` into `output_dir_in`, yielding
    # progress strings for the UI as the download proceeds.
    session = requests.Session()
    try:
        headers = {}
        mode = 'wb'
        filename = url.split('/')[-1]

        output_dir = str(output_dir_in)
        local_filename = os.path.join(output_dir, filename)

        overw = ''
        if os.path.exists(local_filename):
            if not overwrite:
                yield f"File '{local_filename}' already exists. Aborting."
                return
            else:
                overw = ' [Overwrite existing]'

        filename_lower = filename.lower()
        file_extension = os.path.splitext(filename_lower)[-1]

        if file_extension not in valid_extensions:
            yield f"Invalid file extension: {file_extension}. Only {valid_extensions} files are supported."
            return

        with session.get(url, stream=True, headers=headers, timeout=10) as r:
            r.raise_for_status()

            # Write in 4 KiB chunks, reporting the running byte count.
            block_size = 1024 * 4
            with open(local_filename, mode) as f:
                count = 0
                for data in r.iter_content(block_size):
                    f.write(data)
                    count += len(data)

                    yield f"Downloaded: {count}{overw}"

        if os.path.exists(local_filename):
            downloaded_size = os.path.getsize(local_filename)
            if downloaded_size > 0:
                yield f"File '{filename}' downloaded to '{output_dir}' ({downloaded_size} bytes)."
                print("File Downloaded")
            else:
                print("Downloaded file is zero")
                yield "Failed: downloaded file size is zero."
        else:
            print(f"Error: {local_filename} failed to download.")
            yield f"Error: {local_filename} failed to download."

    except Exception as e:
        print(f"An error occurred: {e}")
        yield f"An error occurred: {e}"

    finally:
        # The session is created before the try block so this close() can
        # never hit an undefined name.
        session.close()
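
# download_file_from_url is a generator (it yields progress strings for the
# Gradio UI). A minimal sketch of consuming it directly; the URL and folder
# below are hypothetical:
def _demo_download():
    for status in download_file_from_url('https://example.com/dataset.txt',
                                         overwrite=False,
                                         output_dir_in='training/datasets'):
        print(status)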