Spaces:
Sleeping
Sleeping
# Experimental | |
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from llama_index.core.callbacks.schema import EventPayload
from llama_index.core.query_engine.sub_question_query_engine import (
    SubQuestionAnswerPair,
)
from llama_index.core.schema import BaseNode, NodeWithScore
from pydantic import BaseModel, ConfigDict, Field, field_validator

from db.db import (
    MessageRoleEnum,
    MessageStatusEnum,
    MessageSubProcessSourceEnum,
    MessageSubProcessStatusEnum,
)
# Metadata key under which a source node carries the id of its database
# document; citations are only built for nodes whose metadata contains it.
DB_DOC_ID_KEY: str = "db_document_id"
class Base(BaseModel):
    """Shared identifier and timestamp fields for the schemas in this module.

    All three fields are optional and default to ``None``.
    """

    # pydantic v2 renamed ``orm_mode`` to ``from_attributes``; the old key is
    # no longer honoured, so the model could not be populated from
    # attribute-bearing objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)

    id: Optional[UUID] = Field(None, description="Unique identifier")
    created_at: Optional[datetime] = Field(None, description="Creation datetime")
    updated_at: Optional[datetime] = Field(None, description="Update datetime")
class BaseMetadataObject(BaseModel):
    """Base class for the metadata payload models defined in this module."""

    # pydantic v2 renamed ``orm_mode`` to ``from_attributes``; the old key is
    # no longer honoured.
    model_config = ConfigDict(from_attributes=True)
class Citation(BaseMetadataObject):
    """A text excerpt from one page of a source document, with its score."""

    document_id: UUID
    text: str
    page_number: int
    # Explicit default: pydantic v2 treats a bare ``Optional[...]`` annotation
    # as a required field.
    score: Optional[float] = None

    @field_validator("document_id")
    @classmethod
    def validate_document_id(cls, value):
        """Return a truthy ``document_id`` as a string; pass falsy values through.

        NOTE(review): the field is annotated ``UUID`` but ends up holding a
        ``str`` after validation - presumably so the model can be stored as
        JSON; confirm against the callers.
        """
        if value:
            return str(value)
        return value

    @classmethod
    def from_node(cls, node_w_score: NodeWithScore) -> "Citation":
        """Build a citation from a scored retrieval node.

        Args:
            node_w_score: Scored node whose ``node.source_node.metadata``
                contains ``"page_label"`` and ``DB_DOC_ID_KEY``.

        Returns:
            A ``Citation`` carrying the node's content, page number, document
            id and score.

        Raises:
            AttributeError: If the node has no ``source_node``.
            KeyError: If either metadata key is missing.
            ValueError: If the page label is not an integer literal.
        """
        node: BaseNode = node_w_score.node
        source_metadata = node.source_node.metadata
        return cls(
            # The document id is stored under DB_DOC_ID_KEY, the same key
            # callers test for before building a citation.
            document_id=source_metadata[DB_DOC_ID_KEY],
            text=node.get_content(),
            page_number=int(source_metadata["page_label"]),
            score=node_w_score.score,
        )
class QuestionAnswerPair(BaseMetadataObject):
    """A sub-question, its answer, and the citations backing that answer."""

    question: str
    # Explicit default: pydantic v2 treats a bare ``Optional[...]`` annotation
    # as a required field.
    answer: Optional[str] = None
    citations: Optional[List[Citation]] = None

    @classmethod
    def from_sub_question_answer_pair(
        cls, sub_question_answer_pair: SubQuestionAnswerPair
    ) -> "QuestionAnswerPair":
        """Convert a llama_index ``SubQuestionAnswerPair`` into this schema.

        Sources are kept only if they have a ``source_node`` whose metadata
        contains ``DB_DOC_ID_KEY``. When no source qualifies, or ``sources``
        is ``None``, ``citations`` is ``None`` rather than an empty list.

        Args:
            sub_question_answer_pair: The pair produced by the sub-question
                query engine.

        Returns:
            The equivalent ``QuestionAnswerPair``.
        """
        sources = sub_question_answer_pair.sources
        citations = None
        if sources is not None:
            usable_sources = [
                node_w_score
                for node_w_score in sources
                if node_w_score.node.source_node is not None
                and DB_DOC_ID_KEY in node_w_score.node.source_node.metadata
            ]
            # Collapse an empty list to None.
            citations = [Citation.from_node(src) for src in usable_sources] or None
        return cls(
            question=sub_question_answer_pair.sub_q.sub_question,
            answer=sub_question_answer_pair.answer,
            citations=citations,
        )
# NOTE: later this will be Union[QuestionAnswerPair, ...] as more payload
# kinds are added.
class SubProcessMetadataKeysEnum(str, Enum):
    """Known keys of a sub-process metadata map."""

    # Reuses the string value of llama_index's sub-question event payload key.
    SUB_QUESTION = EventPayload.SUB_QUESTION.value
# Typing is deliberately loose (arbitrary string keys, ``Any`` values) so that
# changes to the metadata data formats do not break validation.
SubProcessMetadataMap = Dict[Union[SubProcessMetadataKeysEnum, str], Any]
class MessageSubProcess(Base):
    """A sub-process attached to a message, with its source, status and metadata."""

    message_id: UUID
    source: MessageSubProcessSourceEnum
    status: MessageSubProcessStatusEnum
    # Explicit default, consistent with ``Document.metadata_map``: pydantic v2
    # treats a bare ``Optional[...]`` annotation as a required field.
    metadata_map: Optional[SubProcessMetadataMap] = None
class Message(Base):
    """A single message within a conversation, plus its sub-processes."""

    # Conversation this message belongs to.
    conversation_id: UUID
    content: str
    role: MessageRoleEnum
    status: MessageStatusEnum
    sub_processes: List[MessageSubProcess]
class UserMessageCreate(BaseModel):
    """Payload for creating a user message; carries only the message text."""

    content: str
class DocumentMetadataKeysEnum(str, Enum):
    """Known keys of a document's metadata map."""

    SEC_DOCUMENT = "sec_document"
class SecDocumentTypeEnum(str, Enum):
    """Supported SEC document types."""

    TEN_K = "10-K"
    TEN_Q = "10-Q"
class SecDocumentMetadata(BaseModel):
    """Metadata describing a document that is an SEC document.

    The company, document type and year are required; every other field is
    optional and defaults to ``None``.
    """

    company_name: str
    company_ticker: str
    doc_type: SecDocumentTypeEnum
    year: int
    # Explicit ``None`` defaults below: pydantic v2 treats a bare
    # ``Optional[...]`` annotation as a required field.
    quarter: Optional[int] = None
    accession_number: Optional[str] = None
    cik: Optional[str] = None
    period_of_report_date: Optional[datetime] = None
    filed_as_of_date: Optional[datetime] = None
    date_as_of_change: Optional[datetime] = None
# Loosely typed map of document metadata: keys are either a known
# DocumentMetadataKeysEnum member or any other string; values are unconstrained.
DocumentMetadataMap = Dict[Union[DocumentMetadataKeysEnum, str], Any]
class Document(Base):
    """A document identified by its URL, with an optional metadata map."""

    url: str
    # Defaults to None when the document has no metadata.
    metadata_map: Optional[DocumentMetadataMap] = None
class Conversation(Base):
    """A conversation: its messages and the documents associated with it."""

    messages: List[Message]
    documents: List[Document]
class ConversationCreate(BaseModel):
    """Payload for creating a conversation from a list of document ids."""

    document_ids: List[UUID]