Spaces:
Sleeping
Sleeping
import argparse
import logging
from collections import Counter
from typing import Optional

import numpy as np
from sqlalchemy.orm import Session

import common.dependencies as DI
from common.configuration import Configuration
from components.dbo.models.entity import EntityModel
# Module-level logger named after this module; its records propagate to the
# root logger configured just below.
logger = logging.getLogger(__name__)
# Script-style setup: configure the root logger at import time so INFO-level
# messages are emitted (basicConfig does nothing if the root logger already
# has handlers).
logging.basicConfig(level=logging.INFO)
def analyze_embeddings(embeddings: list[Optional[np.ndarray]]) -> dict:
    """
    Compute summary statistics for a list of embeddings.

    ``None`` entries count towards ``total`` but are skipped everywhere else.
    Non-``None`` entries may be NumPy arrays or any array-like (for example a
    plain list of floats); they are converted with ``np.asarray`` so that a
    shape and a norm are always available.

    Args:
        embeddings: Embeddings to analyze; ``None`` marks a missing embedding.

    Returns:
        dict with keys:
            "total":     number of entries, including ``None``;
            "valid":     number of non-``None`` entries;
            "shapes":    mapping ``str(shape)`` -> number of embeddings with
                         that shape;
            "mean_norm": mean of ``np.linalg.norm`` over valid embeddings, or
                         ``None`` when there are none;
            "std_norm":  population standard deviation of those norms, or
                         ``None`` when there are none.
    """
    # np.asarray returns ndarrays unchanged but lets list-backed embeddings work.
    valid_embeddings = [np.asarray(e) for e in embeddings if e is not None]
    if not valid_embeddings:
        return {
            "total": len(embeddings),
            "valid": 0,
            "shapes": {},
            "mean_norm": None,
            "std_norm": None,
        }

    # Counter preserves first-seen order, like the previous manual tally.
    shapes = Counter(str(e.shape) for e in valid_embeddings)
    norms = [np.linalg.norm(e) for e in valid_embeddings]
    return {
        "total": len(embeddings),
        "valid": len(valid_embeddings),
        "shapes": dict(shapes),
        "mean_norm": float(np.mean(norms)),
        "std_norm": float(np.std(norms)),
    }
def analyze_entities(
    dataset_id: int,
    db: Session,
    config: Configuration,
) -> None:
    """
    Log a statistical overview of all entities belonging to one dataset.

    Logs, in this order: the entity count, the distinct entity types and the
    number of entities per type, embedding statistics (via
    ``analyze_embeddings``), text-length statistics, and the first five
    entities as examples. If the dataset has no entities, an error is logged
    and the function returns without logging anything else.

    Args:
        dataset_id: ID of the dataset whose entities are analyzed.
        db: Database session used to query ``EntityModel`` rows.
        config: Application configuration. Not read by this function; kept
            so the call signature stays stable for existing callers.
    """
    entities = (
        db.query(EntityModel)
        .filter(EntityModel.dataset_id == dataset_id)
        .all()
    )
    if not entities:
        logger.error(f"No entities found for dataset {dataset_id}")
        return

    logger.info(f"Total entities: {len(entities)}")
    _log_type_stats(entities)
    _log_embedding_stats(entities)
    _log_text_stats(entities)
    _log_examples(entities)


def _log_type_stats(entities: list["EntityModel"]) -> None:
    """Log the distinct entity types and how many entities have each type."""
    logger.info(f"Entity types: {set(e.entity_type for e in entities)}")
    # Counter keeps first-seen order, matching the previous manual dict tally.
    type_stats = Counter(e.entity_type for e in entities)
    logger.info("Entities per type:")
    for entity_type, count in type_stats.items():
        logger.info(f" {entity_type}: {count}")


def _log_embedding_stats(entities: list["EntityModel"]) -> None:
    """Log count, shape and norm statistics of the entities' embeddings."""
    embedding_stats = analyze_embeddings([e.embedding for e in entities])
    logger.info("\nEmbedding statistics:")
    logger.info(f" Total embeddings: {embedding_stats['total']}")
    logger.info(f" Valid embeddings: {embedding_stats['valid']}")
    logger.info(" Shapes:")
    for shape, count in embedding_stats['shapes'].items():
        logger.info(f" {shape}: {count}")
    # Norm statistics are None when no entity has an embedding.
    if embedding_stats['mean_norm'] is not None:
        logger.info(f" Mean norm: {embedding_stats['mean_norm']:.4f}")
        logger.info(f" Std norm: {embedding_stats['std_norm']:.4f}")


def _log_text_stats(entities: list["EntityModel"]) -> None:
    """Log mean and population std of text and search-text lengths."""
    # Missing (None/empty) texts count as length 0 instead of raising TypeError;
    # the caller guarantees `entities` is non-empty, so np.mean is well-defined.
    text_lengths = [len(e.text) if e.text else 0 for e in entities]
    search_text_lengths = [
        len(e.in_search_text) if e.in_search_text else 0 for e in entities
    ]
    logger.info("\nText statistics:")
    logger.info(f" Mean text length: {np.mean(text_lengths):.2f}")
    logger.info(f" Std text length: {np.std(text_lengths):.2f}")
    logger.info(f" Mean search text length: {np.mean(search_text_lengths):.2f}")
    logger.info(f" Std search text length: {np.std(search_text_lengths):.2f}")


def _log_examples(entities: list["EntityModel"], limit: int = 5) -> None:
    """Log identifying fields and the embedding of the first `limit` entities."""
    logger.info("\nExample entities:")
    for e in entities[:limit]:
        logger.info(f" ID: {e.uuid}")
        logger.info(f" Name: {e.name}")
        logger.info(f" Type: {e.entity_type}")
        logger.info(f" Embedding: {e.embedding}")
        if e.embedding is not None:
            logger.info(f" Embedding shape: {e.embedding.shape}")
        logger.info(" ---")
def main(argv: Optional[list[str]] = None) -> None:
    """
    Script entry point: parse CLI arguments and analyze one dataset.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behavior; passing a list
            lets the script be driven programmatically or from tests.
    """
    parser = argparse.ArgumentParser(description="Analyze entities in dataset")
    parser.add_argument("dataset_id", type=int, help="Dataset ID")
    parser.add_argument(
        "--config",
        type=str,
        default="config_dev.yaml",
        help="Path to config file",
    )
    args = parser.parse_args(argv)

    config = Configuration(args.config)
    db = DI.get_db()
    # NOTE(review): `db()` is used as a context manager; if it already closes
    # the session on exit, the explicit close() below is redundant but
    # harmless. Kept until the behavior of DI.get_db() is confirmed.
    with db() as session:
        try:
            analyze_entities(args.dataset_id, session, config)
        finally:
            session.close()


if __name__ == "__main__":
    main()