Spaces:
Paused
Paused
import os | |
import re | |
import gc | |
import torch | |
import transformers | |
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter | |
# Path of the .env file this module reads and rewrites:
# "<WRITABLE_DIR>/.env", with WRITABLE_DIR falling back to /tmp when unset.
ENV_FILE_PATH = os.path.join(os.environ.get("WRITABLE_DIR", "/tmp"), ".env")
def remove_markdown(text: str) -> str:
    """Convert Markdown-formatted text to plain text.

    Removes the following Markdown constructs:
      * fence markers (the code inside fenced blocks is kept),
      * headers,
      * bold/italic emphasis and strikethrough,
      * inline code markers,
      * images (removed entirely),
      * links (the link text is kept),
      * blockquote markers, list markers and horizontal rules.
    It then deletes any leftover ``*``, ``_``, ``~`` or backtick characters
    and trims surrounding whitespace.

    Args:
        text: Markdown source.

    Returns:
        The plain-text rendering. Note that every ``_`` is removed, including
        underscores inside identifiers such as ``snake_case``.
    """
    # Remove the opening fence line (```lang) and then any remaining fences.
    text = re.sub(r'```[a-zA-Z]*\n', '', text)
    text = re.sub(r'```', '', text)
    # Line-anchored patterns below use [ \t]* for leading indentation rather
    # than \s*: under re.MULTILINE, \s* would also consume the newline of a
    # preceding blank line and merge separate paragraphs.
    # Remove headers
    text = re.sub(r'^[ \t]*#+\s+', '', text, flags=re.MULTILINE)
    # Remove bold and italic
    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
    text = re.sub(r'__(.*?)__', r'\1', text)
    text = re.sub(r'\*(.*?)\*', r'\1', text)
    text = re.sub(r'_(.*?)_', r'\1', text)
    # Remove strikethrough
    text = re.sub(r'~~(.*?)~~', r'\1', text)
    # Remove inline code
    text = re.sub(r'`(.*?)`', r'\1', text)
    # Remove images BEFORE links: the link pattern also matches the
    # "[alt](url)" tail of "![alt](url)", which would leave "!alt" behind
    # and make the image pattern unreachable.
    text = re.sub(r'!\[(.*?)\]\((.*?)\)', '', text)
    # Remove links, keeping the link text
    text = re.sub(r'\[(.*?)\]\((.*?)\)', r'\1', text)
    # Remove blockquotes
    text = re.sub(r'^[ \t]*>\s+', '', text, flags=re.MULTILINE)
    # Remove list markers (bulleted and numbered)
    text = re.sub(r'^[ \t]*[\*\+-]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^[ \t]*\d+\.\s+', '', text, flags=re.MULTILINE)
    # Remove horizontal rules
    text = re.sub(r'^[ \t]*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
    # Remove any remaining markdown symbols
    text = re.sub(r'[*_~`]', '', text)
    return text.strip()
def remove_outer_markdown_block(chunk, _acc={"b": ""}):
    """Unwrap ```markdown fenced blocks from text that arrives in pieces.

    Meant to be called once per streamed chunk. Each call works as follows:
      1. The chunk is appended to the buffer ``_acc["b"]``.
      2. Every complete ```markdown ... ``` block in the buffer is replaced by
         its inner content; the text before each block is kept.
      3. If an opening ```markdown fence (matched case-insensitively) is still
         unclosed, the remainder stays buffered until a later call supplies the
         closing fence. Otherwise everything is emitted and the buffer is
         cleared.

    NOTE(review): ``_acc`` is a mutable default, so this buffer is shared by
    every caller in the process and persists between calls. This has two
    consequences:
      * An opening fence that is never closed remains buffered and is
        prepended to the next call's input.
      * A fence split across chunks (e.g. "```" then "markdown") is not
        recognised and passes through unchanged.

    Args:
        chunk: The next piece of text.
        _acc: Internal buffer holder; not meant to be passed by callers.

    Returns:
        The text that can be emitted now (possibly an empty string).
    """
    fence = re.compile(r'```markdown\s*\n(.*?)\n?```', re.DOTALL | re.IGNORECASE)
    pending = _acc["b"] + chunk
    pieces = []
    match = fence.search(pending)
    while match:
        start, end = match.span()
        # Keep whatever preceded the fence, plus the unwrapped body.
        pieces.append(pending[:start] + match.group(1))
        pending = pending[end:]
        match = fence.search(pending)
    if '```markdown' in pending.lower():
        # An opening fence is still waiting for its closing ```; hold it back.
        _acc["b"] = pending
    else:
        pieces.append(pending)
        _acc["b"] = ""
    return "".join(pieces)
def clear_gpu_memory():
    """Release cached CUDA memory and reset peak-memory stats on every GPU.

    Does nothing when CUDA is unavailable. For each visible device it does
    the following:
      * empties PyTorch's CUDA cache,
      * resets the peak-memory statistics,
      * waits for queued work on that device,
      * calls ``torch.cuda.ipc_collect()``.
    Memory still held by live tensors is not released by these calls; drop
    those references (and run garbage collection) beforehand.

    The caller's current CUDA device is restored afterwards.

    Raises:
        Exception: wrapping (and chained to) any error raised by a CUDA call.
    """
    if not torch.cuda.is_available():
        return
    try:
        print("Starting the GPU memory cleanup process...")
        torch.cuda.empty_cache()
        device_count = torch.cuda.device_count()
        print(f"Number of GPUs: {device_count}")
        for device_id in range(device_count):
            print(f"Clearing GPU memory and cache for device {device_id}...")
            # torch.cuda.device restores the previously current device on exit;
            # a bare set_device would leave the last GPU selected for every
            # later allocation that does not name a device explicitly.
            with torch.cuda.device(device_id):
                torch.cuda.reset_peak_memory_stats(device_id)
                torch.cuda.empty_cache()
                # Block until all queued work on this device has finished.
                torch.cuda.synchronize(device_id)
                torch.cuda.ipc_collect()
    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise Exception(f"Error clearing GPU memory and cache: {e}") from e
def _describe_model(obj):
    """Return a readable identifier for a model/tokenizer object, for logging."""
    if hasattr(obj, "name_or_path"):
        return obj.name_or_path
    if hasattr(obj, "config") and hasattr(obj.config, "_name_or_path"):
        return obj.config._name_or_path
    return str(type(obj))  # Fall back to the type when no name is exposed


def clear_memory():
    """Run garbage collection, then report model objects that are still alive.

    Python cannot force-delete an object that is referenced elsewhere.
    Applying ``del`` to a name bound while iterating ``gc.get_objects()`` only
    unbinds that local name, so this function does not attempt per-object
    deletion. Instead it does two things:
      1. Calls ``gc.collect()`` to reclaim unreachable objects, including
         reference cycles.
      2. Scans the surviving objects and logs every transformers model,
         tokenizer, or SentenceTransformer-like object. These are still
         referenced somewhere, and the owner must drop that reference for
         the memory to be freed.
    """
    print("Running garbage collection to release unreferenced tensors and models...")
    gc.collect()
    # Resolve the class tuple once rather than on every scanned object.
    model_types = (
        transformers.PreTrainedModel,
        transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    )
    for obj in gc.get_objects():
        try:
            if isinstance(obj, model_types) or "SentenceTransformer" in str(type(obj)):
                print(f"Model still referenced after garbage collection: {_describe_model(obj)}")
        except Exception as e:
            # Best effort: unusual objects may raise during inspection.
            print(f"Error while inspecting object: {e}")
# Function to chunk text | |
def chunk_text(input_text, max_chunk_length=100, overlap=0, context_length=None):
    """Split ``input_text`` into chunks and annotate each with a character span.

    The text is split in up to two passes:
      1. ``RecursiveCharacterTextSplitter``, preferring the separators
         "\\n\\n", "\\n", ". ", " ", then "".
      2. Unless a valid ``context_length`` is supplied, each resulting chunk
         is split again by a ``TokenTextSplitter`` limited to
         ``max_chunk_length`` tokens.

    Args:
        input_text: Text to split.
        max_chunk_length: Character chunk size when ``context_length`` is not
            a positive int, and the token chunk size of the second pass.
        overlap: Overlap passed to both splitters.
        context_length: Optional positive int used as the character chunk
            size; when it is valid the token pass is skipped.

    Returns:
        ``(chunks, spans)``. ``chunks`` holds the chunk strings, and
        ``spans[i]`` is a ``(start, end)`` pair of cumulative offsets into
        ``"".join(chunks)``. These match offsets into ``input_text`` only when
        the chunks tile it exactly (nothing dropped by the splitter and
        ``overlap == 0``).
    """
    # Decide once whether context_length is usable so both passes agree: an
    # invalid truthy value (e.g. -1 or 0.5) falls back to max_chunk_length for
    # the character pass and therefore also gets the token pass.
    use_context = isinstance(context_length, int) and context_length > 0
    chunk_size = context_length if use_context else max_chunk_length
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""],
        chunk_size=chunk_size,
        chunk_overlap=overlap,
        length_function=len,
    )
    chunks = splitter.split_text(input_text)
    token_splitter = (
        None
        if use_context
        else TokenTextSplitter(chunk_size=max_chunk_length, chunk_overlap=overlap)
    )
    final_chunks = []
    span_annotations = []
    current_position = 0
    for chunk in chunks:
        # Second (token) pass only when no valid context_length was given.
        pieces = token_splitter.split_text(chunk) if token_splitter else [chunk]
        for piece in pieces:
            final_chunks.append(piece)
            span_annotations.append((current_position, current_position + len(piece)))
            current_position += len(piece)
    return final_chunks, span_annotations
# Function to read .env file | |
def read_env():
    """Parse the file at ENV_FILE_PATH into a ``{name: value}`` dict.

    Parsing rules:
      * Blank lines, lines starting with "#", and lines without "=" are
        skipped.
      * Every other line is split at its first "=", and both sides are
        whitespace-stripped. Surrounding quotes are kept verbatim.
      * Later duplicates overwrite earlier ones.

    Returns an empty dict when the file does not exist.
    """
    if os.path.exists(ENV_FILE_PATH):
        parsed = {}
        with open(ENV_FILE_PATH, "r", encoding="utf-8") as handle:
            for raw in handle:
                entry = raw.strip()
                if entry and not entry.startswith("#"):
                    name, sep, value = entry.partition("=")
                    if sep:
                        parsed[name.strip()] = value.strip()
        return parsed
    return {}
# Function to update .env file | |
def update_env_vars(new_values: dict):
    """Replace the file at ENV_FILE_PATH with ``new_values``.

    The whole file is rewritten as one ``KEY=value`` line per entry, so
    callers that want to keep existing variables must merge them into
    ``new_values`` first.

    The write is atomic: contents go to a temporary file in the same
    directory, which is then moved into place with ``os.replace``. A failure
    part-way therefore leaves the previous file intact. The temporary file
    comes from ``tempfile.mkstemp`` and is readable/writable only by its
    owner, so API keys are not left world-readable in a shared directory
    such as /tmp.

    Raises:
        ValueError: if a key or value contains a line break, which would
            corrupt the one-entry-per-line format.
        OSError: if the temporary file cannot be created or moved into place.
    """
    import tempfile  # local import: only needed for the atomic write

    entries = []
    for var, val in new_values.items():
        entry = f"{var}={val}"
        if "\n" in entry or "\r" in entry:
            # Do not echo the value: it is typically a secret.
            raise ValueError(f"Line break in .env entry for {var!r}")
        entries.append(entry + "\n")

    env_dir = os.path.dirname(ENV_FILE_PATH) or "."
    fd, tmp_path = tempfile.mkstemp(dir=env_dir, prefix=".env.", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.writelines(entries)
        os.replace(tmp_path, ENV_FILE_PATH)
    except BaseException:
        # Don't leave a stray temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# Function to prepare provider key updates dictionary | |
def prepare_provider_key_updates(provider: str, multiline_keys: str) -> dict:
    """Map newline-separated API keys to numbered env-var names for a provider.

    Blank lines are skipped and each key is whitespace-stripped. Keys are
    numbered from 1 in input order; for example, provider "openai" yields
    OPENAI_API_KEY_1, OPENAI_API_KEY_2, and so on.

    Supported providers are "openai", "google", "xai" and "anthropic". Any
    other provider yields an empty dict.
    """
    env_prefixes = {
        "openai": "OPENAI_API_KEY",
        "google": "GOOGLE_API_KEY",
        "xai": "XAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
    }
    keys = [ln.strip() for ln in multiline_keys.splitlines() if ln.strip()]
    for name, prefix in env_prefixes.items():
        if provider == name:
            return {f"{prefix}_{n}": k for n, k in enumerate(keys, start=1)}
    return {}
# Function to prepare proxy list dictionary | |
def prepare_proxy_list_updates(proxy_list: str) -> dict:
    """Map newline-separated proxies to numbered ``PROXY_<n>`` entries.

    Blank lines are skipped and each proxy is whitespace-stripped.

    Returns:
        A dict ``{"PROXY_1": ..., "PROXY_2": ...}`` numbered from 1 in input
        order. The return annotation is ``dict`` because that is what the
        function builds and returns.
    """
    proxies = [p.strip() for p in proxy_list.splitlines() if p.strip()]
    return {f"PROXY_{i}": proxy for i, proxy in enumerate(proxies, start=1)}