muryshev's picture
f
1c50aa4
raw
history blame
24.2 kB
import json
import logging
import os
import shutil
import zipfile
from datetime import datetime
from multiprocessing import Process
from pathlib import Path
from typing import Optional
from threading import Lock
import pandas as pd
import torch
from fastapi import BackgroundTasks, HTTPException, UploadFile
from common.common import get_source_format
from common.configuration import Configuration
from components.embedding_extraction import EmbeddingExtractor
from components.parser.features.documents_dataset import DocumentsDataset
from components.parser.pipeline import DatasetCreationPipeline
from components.parser.xml.structures import ParsedXML
from components.parser.xml.xml_parser import XMLParser
from sqlalchemy.orm import Session
from components.dbo.models.acronym import Acronym
from components.dbo.models.dataset import Dataset
from components.dbo.models.dataset_document import DatasetDocument
from components.dbo.models.document import Document
from schemas.dataset import Dataset as DatasetSchema
from schemas.dataset import DatasetExpanded as DatasetExpandedSchema
from schemas.dataset import DatasetProcessing
from schemas.dataset import DocumentsPage as DocumentsPageSchema
from schemas.dataset import SortQueryList
from schemas.document import Document as DocumentSchema
# Module-level logger used by DatasetService.
logger = logging.getLogger(__name__)
class DatasetService:
    """
    Service for working with datasets.

    Datasets live in the database (``Dataset`` rows linked to ``Document``
    rows through ``DatasetDocument``) and on disk under
    ``regulations_path/<dataset_id>``. While a draft is being applied, its
    progress is written to ``<APP_TMP_PATH>/tmp.json``; as long as that path
    exists, most operations are rejected with HTTP 409
    (see ``raise_if_processing``).
    """

    # Document statuses that are listed by ``get_dataset``.
    _VISIBLE_STATUSES = ('Актуальный', 'Требует актуализации', 'Упразднён')

    def __init__(
        self,
        vectorizer: EmbeddingExtractor,
        config: Configuration,
        db: Session
    ) -> None:
        """
        Args:
            vectorizer: Embedding extractor handed to the dataset pipeline.
            config: Application configuration; file locations are read from
                ``config.db_config.files``.
            db: Session source. NOTE(review): annotated as ``Session`` but it
                is called (``self.db()``) and used as a context manager, so it
                is presumably a session factory — confirm with the caller.
        """
        logger.info("DatasetService initializing")
        self.db = db
        self.config = config
        self.parser = XMLParser()
        self.vectorizer = vectorizer
        self.regulations_path = Path(config.db_config.files.regulations_path)
        self.documents_path = Path(config.db_config.files.documents_path)
        # Scratch directory for the progress file and uploaded archives.
        self.tmp_path = Path(os.environ.get("APP_TMP_PATH", '.'))
        logger.info("DatasetService initialized")

    def _progress_file(self) -> Path:
        """Path whose existence signals that dataset processing is running."""
        return self.tmp_path / 'tmp.json'

    def get_dataset(
        self,
        dataset_id: int,
        page: int = 1,
        page_size: int = 20,
        search: str = '',
        sort: Optional[SortQueryList] = None,
    ) -> DatasetExpandedSchema:
        """
        Return paginated information about a dataset and its documents.

        Args:
            dataset_id: Id of the dataset.
            page: 1-based page number; values below 1 are treated as the
                first page when computing the offset.
            page_size: Number of documents per page.
            search: Substring searched for in document titles. SQL ``LIKE``
                wildcards (``%``, ``_``) in it are matched literally.
            sort: Requested ordering; ``None``/empty means newest id first.

        Raises:
            HTTPException: 409 if processing is in progress,
                404 if the dataset does not exist.
        """
        logger.info(
            f"Getting dataset {dataset_id} (page={page}, size={page_size}, search='{search}')"
        )
        self.raise_if_processing()
        search = search or ''
        with self.db() as session:
            dataset: Dataset = (
                session.query(Dataset).filter(Dataset.id == dataset_id).first()
            )
            if not dataset:
                raise HTTPException(status_code=404, detail='Dataset not found')

            filtered = (
                session.query(Document)
                .join(DatasetDocument, DatasetDocument.document_id == Document.id)
                .filter(DatasetDocument.dataset_id == dataset_id)
                .filter(Document.status.in_(self._VISIBLE_STATUSES))
                # autoescape: user input must not act as LIKE wildcards.
                .filter(Document.title.contains(search, autoescape=True))
            )
            # Count on the unsorted query: ordering does not affect the total.
            total_documents = filtered.count()

            offset = max(page - 1, 0) * page_size
            documents = (
                self.sort_documents(filtered, sort)
                .offset(offset)
                .limit(page_size)
                .all()
            )

            return DatasetExpandedSchema(
                id=dataset.id,
                name=dataset.name,
                isDraft=dataset.is_draft,
                isActive=dataset.is_active,
                dateCreated=dataset.date_created,
                data=DocumentsPageSchema(
                    page=[
                        DocumentSchema(
                            id=document.id,
                            name=document.title,
                            owner=document.owner,
                            status=document.status,
                        )
                        for document in documents
                    ],
                    total=total_documents,
                    pageNumber=page,
                    pageSize=page_size,
                ),
            )

    def get_datasets(self) -> list[DatasetSchema]:
        """
        Return a summary of every dataset.

        Raises:
            HTTPException: 409 if processing is in progress.
        """
        self.raise_if_processing()
        with self.db() as session:
            datasets: list[Dataset] = session.query(Dataset).all()
            return [
                DatasetSchema(
                    id=dataset.id,
                    name=dataset.name,
                    isDraft=dataset.is_draft,
                    isActive=dataset.is_active,
                    dateCreated=dataset.date_created
                )
                for dataset in datasets
            ]

    def create_draft(self, parent_id: int) -> DatasetSchema:
        """
        Create a draft dataset that links the same documents as its parent.

        The draft is named after the current timestamp and is neither active
        nor applied.

        Raises:
            HTTPException: 409 if processing is in progress, 404 if the parent
                does not exist, 400 if the parent is itself a draft.
        """
        logger.info(f"Creating draft dataset from parent {parent_id}")
        self.raise_if_processing()
        with self.db() as session:
            parent = session.query(Dataset).filter(Dataset.id == parent_id).first()
            if not parent:
                raise HTTPException(status_code=404, detail='Parent dataset not found')
            if parent.is_draft:
                raise HTTPException(status_code=400, detail='Parent dataset is draft')

            date = datetime.now()
            dataset = Dataset(
                name=date.strftime('%Y-%m-%d %H:%M:%S'),
                is_draft=True,
                is_active=False,
            )
            parent_links = (
                session.query(DatasetDocument)
                .filter(DatasetDocument.dataset_id == parent_id)
                .all()
            )
            # Copy the *document* ids referenced by the parent's link rows
            # (not the link rows' own ids). The new links' dataset_id is
            # filled in via the ``dataset.documents`` relationship on flush.
            dataset.documents = [
                DatasetDocument(document_id=link.document_id)
                for link in parent_links
            ]
            session.add(dataset)
            session.commit()
            session.refresh(dataset)
            new_dataset_id = dataset.id

        return self.get_dataset(new_dataset_id)

    def delete_dataset(self, dataset_id: int) -> None:
        """
        Delete a dataset that is neither the default nor the active one.

        Only the database row is deleted here.

        Raises:
            HTTPException: 409 if processing is in progress, 404 if the
                dataset does not exist, 400 for the default dataset,
                403 for the active dataset.
        """
        logger.info(f"Deleting dataset {dataset_id}")
        self.raise_if_processing()
        with self.db() as session:
            dataset: Dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
            if not dataset:
                raise HTTPException(status_code=404, detail='Dataset not found')
            if dataset.name == 'default':
                raise HTTPException(
                    status_code=400, detail='Default dataset cannot be deleted'
                )
            if dataset.is_active:
                raise HTTPException(
                    status_code=403, detail='Active dataset cannot be deleted'
                )
            session.delete(dataset)
            session.commit()

    def apply_draft_task(self, dataset_id: int):
        """
        Apply a draft dataset and make it the active one.

        Scheduled as a background task by ``activate_dataset``. On success
        the dataset is marked as non-draft and active, and the previously
        active dataset (if a different one) is deactivated.

        Raises:
            HTTPException: 404 if the dataset does not exist; anything raised
                by ``apply_draft`` is logged and re-raised.
        """
        try:
            with self.db() as session:
                dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
                if not dataset:
                    raise HTTPException(status_code=404, detail=f"Dataset with id {dataset_id} not found")
                active_dataset = (
                    session.query(Dataset).filter(Dataset.is_active == True).first()  # noqa: E712
                )
                self.apply_draft(dataset, session)
                # Never deactivate the dataset we are activating.
                if active_dataset is not None and active_dataset.id != dataset.id:
                    active_dataset.is_active = False
                dataset.is_draft = False
                dataset.is_active = True
                session.commit()
        except Exception as e:
            logger.error(f"Error applying draft: {e}", exc_info=True)
            raise

    def activate_dataset(self, dataset_id: int, background_tasks: BackgroundTasks) -> DatasetExpandedSchema:
        """
        Activate a dataset.

        A non-draft dataset is activated immediately (the currently active one
        is deactivated). A draft is handed to ``apply_draft_task`` through
        ``background_tasks``, so the returned data reflects the state before
        that task has run.

        Raises:
            HTTPException: 409 if processing is in progress, 404 if the
                dataset does not exist, 400 if it is already active.
        """
        logger.info(f"Activating dataset {dataset_id}")
        self.raise_if_processing()
        with self.db() as session:
            dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
            if not dataset:
                raise HTTPException(status_code=404, detail='Dataset not found')
            if dataset.is_active:
                raise HTTPException(status_code=400, detail='Dataset is already active')

            if dataset.is_draft:
                background_tasks.add_task(self.apply_draft_task, dataset_id)
            else:
                active_dataset = session.query(Dataset).filter(Dataset.is_active).first()
                dataset.is_active = True
                if active_dataset:
                    active_dataset.is_active = False
                session.commit()

        return self.get_dataset(dataset_id)

    def get_processing(self) -> DatasetProcessing:
        """
        Describe the current dataset-processing state.

        Returns ``status='ready'`` when the progress path does not exist.
        Otherwise returns ``status='in_progress'`` with the progress counters
        and dataset name when they can be read, or with ``None`` fields when
        the progress file is unreadable (e.g. mid-write, or the path is the
        upload directory used by ``upload_zip``).
        """
        tmp_file = self._progress_file()
        if not tmp_file.exists():
            return DatasetProcessing(
                status='ready',
                total=None,
                current=None,
                datasetName=None,
            )

        try:
            with open(tmp_file, 'r', encoding='utf-8') as f:
                info = json.load(f)
            total = info['total']
            current = info['current']
            processed_dataset_id = info['dataset_id']
        except (OSError, ValueError, KeyError, TypeError) as e:
            logger.warning(f"Error loading processing info: {e}")
            return DatasetProcessing(
                status='in_progress',
                total=None,
                current=None,
                datasetName=None,
            )

        with self.db() as session:
            dataset = (
                session.query(Dataset)
                .filter(Dataset.id == processed_dataset_id)
                .first()
            )
            # The dataset may have been removed while the file still exists.
            dataset_name = dataset.name if dataset is not None else None

        return DatasetProcessing(
            status='in_progress',
            total=total,
            current=current,
            datasetName=dataset_name,
        )

    def upload_zip(self, file: UploadFile) -> DatasetExpandedSchema:
        """
        Create a new dataset from an uploaded ZIP archive of xml/docx files.

        Raises:
            HTTPException: 409 if processing is in progress.
        """
        logger.info(f"Uploading ZIP file {file.filename}")
        self.raise_if_processing()
        # The archive is unpacked into a directory located at the progress
        # path (``tmp.json``). While it exists, get_processing() reports
        # 'in_progress', so other requests are rejected during the upload.
        extract_dir = self._progress_file()
        file_location = extract_dir / 'tmp.zip'
        logger.debug(f"Saving uploaded file to {file_location}")
        extract_dir.mkdir(parents=True, exist_ok=True)
        try:
            with open(file_location, 'wb') as f:
                # Stream instead of reading the whole upload into memory.
                shutil.copyfileobj(file.file, f)
            with zipfile.ZipFile(file_location, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)
            dataset = self.create_dataset_from_directory(
                is_default=False,
                directory_with_xmls=extract_dir,
                directory_with_ready_dataset=None,
            )
        finally:
            # Always remove the directory; a leftover one would make the
            # service report 'in_progress' forever.
            shutil.rmtree(extract_dir, ignore_errors=True)

        return self.get_dataset(dataset.id)

    def apply_draft(
        self,
        dataset: Dataset,
        session,
    ) -> None:
        """
        Run the dataset-creation pipeline for a draft dataset.

        Progress is written to the progress file while the pipeline runs; the
        file is removed afterwards whatever the outcome.

        Args:
            dataset: Draft dataset to process.
            session: Open session the dataset is attached to; used to load
                the acronyms of the dataset's documents.

        Raises:
            HTTPException: 400 if the dataset is not a draft,
                500 if the pipeline fails.
        """
        torch.set_num_threads(1)
        logger.info(f"Applying draft dataset {dataset.id}")
        if not dataset.is_draft:
            logger.error(f"Dataset {dataset.id} is not a draft")
            raise HTTPException(
                status_code=400, detail='Dataset is not draft but trying to apply it'
            )

        progress_file = self._progress_file()
        progress_file.parent.mkdir(parents=True, exist_ok=True)

        def progress_callback(current: int, total: int) -> None:
            # Persist progress only about every 1% of the work and log it
            # about every 10%, to keep file writes and log volume low.
            log_step = max(total // 100, 1)
            if current % log_step != 0:
                return
            if total > 10 and current % (total // 10) == 0:
                logger.info(
                    f"Processing dataset {dataset.id}: {current}/{total}"
                )
            with open(progress_file, 'w', encoding='utf-8') as f:
                json.dump(
                    {
                        'total': total,
                        'current': current,
                        'dataset_id': dataset.id,
                    },
                    f,
                )

        links = list(dataset.documents)
        document_ids = [link.document_id for link in links]
        document_formats = [link.document.source_format for link in links]
        prepared_abbreviations = (
            session.query(Acronym).filter(Acronym.document_id.in_(document_ids)).all()
        )
        pipeline = DatasetCreationPipeline(
            dataset_id=dataset.id,
            vectorizer=self.vectorizer,
            prepared_abbreviations=prepared_abbreviations,
            document_ids=document_ids,
            document_formats=document_formats,
            datasets_path=self.regulations_path,
            documents_path=self.documents_path,
            save_intermediate_files=True,
        )

        try:
            # The initial write creates the progress file, so it belongs
            # inside the try to be covered by the cleanup below.
            progress_callback(0, 1000)
            pipeline.run(progress_callback)
        except Exception as e:
            logger.error(f"Error running pipeline: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e
        finally:
            progress_file.unlink(missing_ok=True)

    def raise_if_processing(self) -> None:
        """
        Raise if dataset processing has not finished yet.

        Raises:
            HTTPException: 409 while ``get_processing`` reports 'in_progress'.
        """
        if self.get_processing().status == 'in_progress':
            logger.error("Dataset processing is already in progress")
            raise HTTPException(
                status_code=409, detail='Dataset processing is in progress'
            )

    def create_dataset_from_directory(
        self,
        is_default: bool,
        directory_with_xmls: Path,
        directory_with_ready_dataset: Path | None = None,
    ) -> Dataset:
        """
        Create a dataset from a directory of xml/docx documents.

        Every parsable file becomes a ``Document`` linked to the new dataset
        and is copied to ``documents_path`` as ``<document id>.<format>``.
        A ``documents.csv`` listing is written to
        ``regulations_path/<dataset id>``.

        Args:
            is_default: Create the dataset named 'default' and mark it active.
            directory_with_xmls: Directory with the source documents.
            directory_with_ready_dataset: Directory with an already processed
                dataset. It is moved to ``regulations_path/<dataset id>`` and
                the dataset is created as non-draft. If omitted, the dataset
                is created as a draft that still needs full processing (e.g.
                when created from scripts).

        Returns:
            Dataset: The created dataset (detached from its session).
        """
        logger.info(
            f"Creating {'default' if is_default else 'new'} dataset from directory {directory_with_xmls}"
        )
        with self.db() as session:
            documents = []
            date = datetime.now()
            name = 'default' if is_default else date.strftime("%Y-%m-%d %H:%M:%S")
            dataset = Dataset(
                name=name,
                is_draft=directory_with_ready_dataset is None,
                is_active=is_default,
            )
            session.add(dataset)

            for subpath in self._get_recursive_dirlist(directory_with_xmls):
                document, relation = self._create_document(
                    directory_with_xmls, subpath, dataset
                )
                if document is None:
                    continue
                documents.append(document)
                session.add(document)
                session.add(relation)
            logger.info(f"Created {len(documents)} documents")

            # Flush so that dataset and document ids are assigned.
            session.flush()

            if directory_with_ready_dataset is not None:
                target = self.regulations_path / str(dataset.id)
                target.parent.mkdir(parents=True, exist_ok=True)
                shutil.move(directory_with_ready_dataset, target)
                logger.info(f"Moved ready dataset to {target}")

            self.documents_path.mkdir(parents=True, exist_ok=True)
            for document in documents:
                session.refresh(document)
                old_filename = document.filename
                new_filename = '{}.{}'.format(document.id, document.source_format)
                shutil.copy(
                    directory_with_xmls / old_filename, self.documents_path / new_filename
                )
                document.filename = new_filename
            logger.info("Documents renamed with ids")

            session.commit()
            session.refresh(dataset)
            dataset_id = dataset.id

        logger.info(f"Dataset {dataset_id} created")
        df = self.dataset_to_pandas(dataset_id)
        folder = self.regulations_path / str(dataset_id)
        folder.mkdir(parents=True, exist_ok=True)
        df.to_csv(folder / 'documents.csv', index=False)
        return dataset

    def create_empty_dataset(self, is_default: bool) -> Dataset:
        """
        Create a non-draft dataset with no documents.

        Writes an empty ``dataset.pkl`` and ``documents.csv`` to
        ``regulations_path/<dataset id>``.

        Args:
            is_default: Create the dataset named 'default' and mark it active.
        """
        with self.db() as session:
            name = (
                'default'
                if is_default
                else datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            )
            dataset = Dataset(
                name=name,
                is_active=is_default,
                is_draft=False,
            )
            session.add(dataset)
            session.commit()
            session.refresh(dataset)
            dataset_id = dataset.id

        self.documents_path.mkdir(parents=True, exist_ok=True)
        folder = self.regulations_path / str(dataset_id)
        folder.mkdir(parents=True, exist_ok=True)
        DocumentsDataset([]).to_pickle(folder / 'dataset.pkl')
        df = self.dataset_to_pandas(dataset_id)
        df.to_csv(folder / 'documents.csv', index=False)
        return dataset

    @staticmethod
    def _get_recursive_dirlist(path: Path) -> list[Path]:
        """
        List all .xml and .docx files (extension case-insensitive) under
        ``path`` at any depth.

        Args:
            path: Directory to scan.

        Returns:
            list[Path]: Sorted file paths relative to ``path``.
        """
        return sorted(
            candidate.relative_to(path)
            for candidate in path.rglob('*')
            if candidate.suffix.lower() in ('.xml', '.docx') and candidate.is_file()
        )

    def _create_document(
        self,
        documents_path: Path,
        subpath: os.PathLike,
        dataset: Dataset,
    ) -> tuple[Document | None, DatasetDocument | None]:
        """
        Build a ``Document`` (and its link to ``dataset``) from a source file.

        The objects are not added to any session here.

        Args:
            documents_path: Directory containing the source documents.
            subpath: Path of the document relative to ``documents_path``.
            dataset: Dataset the document will belong to.

        Returns:
            tuple[Document, DatasetDocument]: The document and its link, or
            ``(None, None)`` if the file could not be parsed.
        """
        logger.debug(f"Creating document from {subpath}")
        try:
            source_format = get_source_format(str(subpath))
            parsed_xml: ParsedXML | None = self.parser.parse(
                documents_path / subpath, include_content=False
            )
            if not parsed_xml:
                logger.warning(f"Failed to parse file: {subpath}")
                return None, None

            document = Document(
                filename=str(subpath),
                title=parsed_xml.name,
                status=parsed_xml.status,
                owner=parsed_xml.owner,
                source_format=source_format,
            )
            relation = DatasetDocument(
                document=document,
                dataset=dataset,
            )
            return document, relation
        except Exception as e:
            # Best effort: one bad file must not abort the whole dataset.
            logger.error(f"Error creating document from {subpath}: {e}", exc_info=True)
            return None, None

    def dataset_to_pandas(self, dataset_id: int) -> pd.DataFrame:
        """
        Return the documents of a dataset as a DataFrame with columns
        ``id``, ``filename``, ``title``, ``status`` and ``owner``.
        """
        with self.db() as session:
            links = (
                session.query(DatasetDocument)
                .filter(DatasetDocument.dataset_id == dataset_id)
                .all()
            )
            documents = (
                session.query(Document)
                .filter(Document.id.in_([link.document_id for link in links]))
                .all()
            )
            return pd.DataFrame(
                [
                    {
                        'id': document.id,
                        'filename': document.filename,
                        'title': document.title,
                        'status': document.status,
                        'owner': document.owner,
                    }
                    for document in documents
                ],
                columns=['id', 'filename', 'title', 'status', 'owner'],
            )

    def get_current_dataset(self) -> Dataset | None:
        """Return the active dataset, or ``None`` if none is active."""
        with self.db() as session:
            return session.query(Dataset).filter(Dataset.is_active == True).first()  # noqa: E712

    def get_default_dataset(self) -> Dataset | None:
        """Return the dataset named 'default', or ``None`` if it does not exist."""
        with self.db() as session:
            return session.query(Dataset).filter(Dataset.name == 'default').first()

    def sort_documents(
        self,
        query: "Query",  # type: ignore
        sort: SortQueryList,
    ) -> "Query":  # type: ignore
        """
        Apply the requested ordering to a document query.

        Fields are applied in order; a direction of 'desc' (any case) sorts
        descending, anything else ascending. Without sort fields, documents
        are ordered by id descending.

        Raises:
            HTTPException: 400 for an unknown sort field.
        """
        columns = {
            'name': Document.title,
            'status': Document.status,
            'owner': Document.owner,
            'id': Document.id,
        }
        if not sort or len(sort.sorts) == 0:
            return query.order_by(Document.id.desc())  # Default sorting

        for sort_query in sort.sorts:
            field = sort_query.field
            column = columns.get(field)
            if column is None:
                raise HTTPException(
                    status_code=400, detail=f'Invalid sort field: {field}'
                )
            descending = (sort_query.direction or '').lower() == 'desc'
            query = query.order_by(column.desc() if descending else column)
        return query