# Scraped file-viewer header (not part of the module):
# author: pszemraj | commit 03e9034 "✨ save parameters" | size: 4.1 kB
"""
utils.py - Utility functions for the project.
"""
import re
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

import torch
from natsort import natsorted
def validate_pytorch2(torch_version: Optional[str] = None) -> bool:
    """
    validate_pytorch2 - check whether a PyTorch version string is a 2.x release

    Args:
        torch_version (str, optional): the version string to check. When None,
            the version of the imported torch module (torch.__version__) is used.

    Returns:
        bool, True if the string starts with "2.<digits>", False otherwise
    """
    if torch_version is None:
        torch_version = torch.__version__
    # re.match only anchors at the start, so any suffix after the numeric
    # components (e.g. "+cu118") does not affect the result.
    return re.match(r"^2\.\d+(\.\d+)*", torch_version) is not None
def get_timestamp() -> str:
    """
    get_timestamp - format the current local time as a compact string

    Returns:
        str, the current time formatted as YYYYMMDD_HHMMSS
    """
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
def truncate_word_count(text, max_words=512):
    """
    truncate_word_count - a helper function for the gradio module

    Words are runs of non-whitespace characters. Leading and trailing
    whitespace does not count towards the word total.

    Parameters
    ----------
    text : str, required, the text to be processed
    max_words : int, optional, the maximum number of words, default=512

    Returns
    -------
    dict, with keys:
        "was_truncated" : bool, whether the text had more than max_words words
        "truncated_text" : str, the first max_words words joined by single
            spaces if truncated, otherwise the original text unchanged
    """
    # str.split() with no separator splits on whitespace runs and never yields
    # empty strings, unlike re.split(r"\s+", ...) which emits "" for leading or
    # trailing whitespace (and for an empty input).
    words = text.split()
    was_truncated = len(words) > max_words
    return {
        "was_truncated": was_truncated,
        "truncated_text": " ".join(words[:max_words]) if was_truncated else text,
    }
def load_examples(src, filetypes=(".txt", ".pdf")):
    """
    load_examples - a helper function for the gradio module to load examples

    Creates the source directory if needed, downloads an example PDF into it
    with wget, then reads every file whose suffix is in filetypes.

    Args:
        src (str or Path): directory holding the example files
        filetypes: collection of file suffixes (including the leading dot)
            to load, default=(".txt", ".pdf")

    Returns:
        list of list, one row per readable example in natural sort order:
        [text, "base", 2, 1024, 0.7, 3.5, 3]
    """
    src = Path(src)
    src.mkdir(exist_ok=True)
    pdf_url = (
        "https://www.dropbox.com/s/y92xy7o5qb88yij/all_you_need_is_attention.pdf?dl=1"
    )
    pdf_path = src / "all_you_need_is_attention.pdf"
    # Requires the `wget` executable on PATH; the exit status is not checked,
    # so a failed download does not raise here.
    subprocess.run(["wget", pdf_url, "-O", str(pdf_path)])
    examples = natsorted([f for f in src.iterdir() if f.suffix in filetypes])

    text_examples = []
    for example in examples:
        try:
            text = example.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            # Files are read as plain text; binary content (such as a real PDF)
            # cannot be decoded, so skip it instead of aborting the whole load.
            # NOTE(review): PDFs would need a text extractor to be usable here.
            continue
        # NOTE(review): the trailing values are presumably default settings for
        # the gradio inputs - confirm against the caller.
        text_examples.append([text, "base", 2, 1024, 0.7, 3.5, 3])
    return text_examples
def load_example_filenames(example_path: Union[str, Path]):
    """
    load_example_filenames - a helper function for the gradio module to load examples

    Args:
        example_path (str or Path): directory to search for .txt files

    Returns:
        dict, the examples (filename: path of the file under example_path)
    """
    example_dir = Path(example_path)
    # "*.txt" has no recursive wildcard, so only direct children are matched
    return {txt_file.name: txt_file for txt_file in example_dir.glob("*.txt")}
def saves_summary(
    summarize_output,
    outpath: Optional[Union[str, Path]] = None,
    add_signature=True,
    **kwargs,
):
    """
    saves_summary - save the summary generated from summarize_via_tokenbatches() to a text file

    summarize_output: output from summarize_via_tokenbatches(); a re-iterable
        collection (it is traversed twice) of dicts, each with a "summary"
        sequence (only the first element is written) and a numeric
        "summary_score" (rounded to 4 decimal places)
    outpath: path to the output file; when None, a timestamped
        "document_summary_<timestamp>.txt" in the current working directory
    add_signature: whether to add a signature to the output file
    kwargs: additional keyword arguments to include in the output file,
        written one per line as "key: value"

    Returns the Path of the written file.
    """
    outpath = (
        Path.cwd() / f"document_summary_{get_timestamp()}.txt"
        if outpath is None
        else Path(outpath)
    )
    sum_text = [s["summary"][0] for s in summarize_output]
    sum_scores = [f"\n - {round(s['summary_score'], 4)}" for s in summarize_output]
    scores_text = "\n".join(sum_scores)
    full_summary = "\n\t".join(sum_text)

    # Write everything through one UTF-8 handle so the whole file shares a
    # single encoding (summaries and parameter values may contain non-ASCII).
    with open(outpath, "w", encoding="utf-8") as fo:
        fo.write(full_summary)
        fo.write("\n\n")
        if add_signature:
            fo.write("\n\n---\n\n")
            fo.write("Generated with the Document Summarization space :)\n\n")
            fo.write("https://hf.co/spaces/pszemraj/document-summarization\n\n")

        fo.write("\n" * 3)
        fo.write("## Section Scores:\n\n")
        fo.write(scores_text)
        fo.write("\n\n")
        fo.write(f"Date: {get_timestamp()}\n\n")
        if kwargs:
            fo.write("---\n\n")
            fo.write("Parameters:\n\n")
            for key, value in kwargs.items():
                fo.write(f"{key}: {value}\n")
    return outpath