import os
import time
import random
import re
import json
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Any, Dict, Tuple, Optional
from functools import wraps
from concurrent.futures import ThreadPoolExecutor, as_completed

from google import genai
from google.genai import types
from openai import OpenAI, RateLimitError
from pydantic import BaseModel
from html_chunking import get_html_chunks
from langchain_text_splitters import HTMLHeaderTextSplitter
from sentence_transformers import SentenceTransformer
class LLMClient(ABC):
    """
    Abstract base class for calling LLM APIs.
    """

    def __init__(self, config: dict = None):
        """
        Initializes the LLMClient with a configuration dictionary.

        Args:
            config (dict): Configuration settings for the LLM client.
        """
        self.config = config or {}

    @abstractmethod
    def call_api(self, prompt: str) -> str:
        """
        Call the underlying LLM API with the given prompt.

        Args:
            prompt (str): The prompt or input text for the LLM.

        Returns:
            str: The response from the LLM.
        """
        pass
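
# Illustrative sketch (not part of the original pipeline): a minimal concrete
# subclass showing what implementing `call_api` looks like, useful for testing.
#
#   class EchoLLMClient(LLMClient):
#       def call_api(self, prompt: str) -> str:
#           return f"echo: {prompt}"
#
#   EchoLLMClient().call_api("hi")  # -> "echo: hi"
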
class GeminiLLMClient(LLMClient):
    """
    Concrete implementation of LLMClient for the Gemini API.
    """

    def __init__(self, config: dict):
        """
        Initializes the GeminiLLMClient with an API key, model name, and optional generation settings.

        Args:
            config (dict): Configuration containing:
                - 'api_key': (optional) API key for Gemini (falls back to GEMINI_API_KEY env var)
                - 'model_name': (optional) the model to use (default 'gemini-2.0-flash')
                - 'generation_config': (optional) dict of GenerateContentConfig parameters
        """
        api_key = config.get("api_key") or os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError(
                "API key for Gemini must be provided in config['api_key'] or GEMINI_API_KEY env var."
            )
        self.client = genai.Client(api_key=api_key)
        self.model_name = config.get("model_name", "gemini-2.0-flash")
        # Allow custom generation settings, falling back to sensible defaults.
        gen_conf = config.get("generation_config", {})
        self.generate_config = types.GenerateContentConfig(
            response_mime_type=gen_conf.get("response_mime_type", "text/plain"),
            temperature=gen_conf.get("temperature"),
            max_output_tokens=gen_conf.get("max_output_tokens"),
            top_p=gen_conf.get("top_p"),
            top_k=gen_conf.get("top_k"),
            # Add any other fields you want to expose here.
        )

    def call_api(self, prompt: str) -> str:
        """
        Call the Gemini API with the given prompt (non-streaming).

        Args:
            prompt (str): The input text for the API.

        Returns:
            str: The generated text from the Gemini API.
        """
        contents = [
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=prompt)],
            )
        ]
        # Non-streaming call returns a full response object.
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=self.generate_config,
        )
        # response.text combines all output parts into a single string.
        return response.text
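
# Example usage (sketch; assumes GEMINI_API_KEY is set in the environment):
#
#   gemini = GeminiLLMClient({"generation_config": {"temperature": 0.2}})
#   print(gemini.call_api("Summarize this page in one sentence."))
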
def extract_markdown_json(text: str) -> Optional[Dict[str, Any]]:
    """
    Find the first Markdown ```json ...``` block in `text`,
    parse it as JSON, and return the resulting dict.
    Returns None if no valid JSON block is found.
    """
    # 1) Look specifically for a ```json code fence
    fence_match = re.search(
        r"```json\s*(\{.*?\})\s*```",
        text,
        re.DOTALL | re.IGNORECASE,
    )
    if not fence_match:
        return None
    json_str = fence_match.group(1)
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        return None
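
# Example (illustrative):
#
#   extract_markdown_json('Sure:\n```json\n{"relevant": 1}\n```')  # -> {"relevant": 1}
#   extract_markdown_json("no fenced block here")                  # -> None
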
def retry_on_ratelimit(max_retries=5, base_delay=1.0, max_delay=10.0):
    """
    Decorator that retries the wrapped call on RateLimitError,
    using exponential backoff with jitter.
    """
    def deco(fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            delay = base_delay
            for attempt in range(max_retries):
                try:
                    return fn(*args, **kwargs)
                except RateLimitError:
                    if attempt == max_retries - 1:
                        # Give up after the last attempt.
                        raise
                    # Back off with jitter, capped at max_delay.
                    sleep = min(max_delay, delay) + random.uniform(0, delay)
                    time.sleep(sleep)
                    delay *= 2
            # Unreachable: the loop either returns or raises.
        return wrapped
    return deco
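
# Example (sketch): wrap any rate-limited API call so transient 429s are retried.
#
#   @retry_on_ratelimit(max_retries=3, base_delay=0.5)
#   def flaky_call():
#       ...
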
class NvidiaLLMClient(LLMClient):
    """
    Concrete implementation of LLMClient for the NVIDIA API (non-streaming).
    """

    def __init__(self, config: dict):
        """
        Initializes the NvidiaLLMClient with an API key, model name, and optional generation settings.

        Args:
            config (dict): Configuration containing:
                - 'api_key': (optional) API key for NVIDIA (falls back to NVIDIA_API_KEY env var)
                - 'model_name': (optional) the model to use (default 'google/gemma-3-1b-it')
                - 'generation_config': (optional) dict of generation parameters like temperature, top_p, etc.
        """
        api_key = config.get("api_key") or os.environ.get("NVIDIA_API_KEY")
        if not api_key:
            raise ValueError(
                "API key for NVIDIA must be provided in config['api_key'] or NVIDIA_API_KEY env var."
            )
        self.client = OpenAI(
            base_url="https://integrate.api.nvidia.com/v1",
            api_key=api_key
        )
        self.model_name = config.get("model_name", "google/gemma-3-1b-it")
        # Store generation settings with sensible defaults.
        gen_conf = config.get("generation_config", {})
        self.temperature = gen_conf.get("temperature", 0.1)
        self.top_p = gen_conf.get("top_p", 0.7)
        self.max_tokens = gen_conf.get("max_tokens", 512)

    def set_model(self, model_name: str):
        """
        Set the model name for the NVIDIA API client.

        Args:
            model_name (str): The name of the model to use.
        """
        self.model_name = model_name

    @retry_on_ratelimit()
    def call_api(self, prompt: str) -> str:
        """
        Call the NVIDIA API with the given prompt (non-streaming).
        Retries on RateLimitError so that call_batch can isolate per-prompt failures.

        Args:
            prompt (str): The input text for the API.

        Returns:
            str: The generated text from the NVIDIA API.
        """
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=self.max_tokens,
            # stream is omitted (defaults to False)
        )
        # For the standard (non-streaming) response,
        # choices[0].message.content holds the generated text.
        return response.choices[0].message.content

    def call_batch(self, prompts, max_workers=8):
        """
        Parallel batch with isolated errors: each prompt that still
        fails after retries is marked, but the others succeed.
        """
        results = [None] * len(prompts)
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            futures = {ex.submit(self.call_api, p): i for i, p in enumerate(prompts)}
            for fut in as_completed(futures):
                idx = futures[fut]
                try:
                    results[idx] = fut.result()
                except RateLimitError:
                    # Mark prompts that exhausted their retries instead of failing the whole batch.
                    results[idx] = "<failed after retries>"
        return results
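
# Example usage (sketch; assumes NVIDIA_API_KEY is set in the environment):
#
#   nvidia = NvidiaLLMClient({"model_name": "google/gemma-3-1b-it"})
#   answers = nvidia.call_batch(["What is HTML?", "What is JSON?"], max_workers=4)
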
class AIExtractor:
    def __init__(self, llm_client: LLMClient, prompt_template: str):
        """
        Initializes the AIExtractor with a specific LLM client and configuration.

        Args:
            llm_client (LLMClient): An instance of a class that implements the LLMClient interface.
            prompt_template (str): The template to use for generating prompts for the LLM.
                Should contain placeholders for dynamic content,
                e.g., "Extract the following information: {content} based on schema: {schema}"
        """
        self.llm_client = llm_client
        self.prompt_template = prompt_template

    def extract(self, content: str, schema: BaseModel) -> str:
        """
        Extracts structured information from the given content based on the provided schema.

        Args:
            content (str): The raw content to extract information from.
            schema (BaseModel): A Pydantic model defining the structure of the expected output.

        Returns:
            str: The structured JSON object as a string.
        """
        prompt = self.prompt_template.format(content=content, schema=schema.model_json_schema())
        response = self.llm_client.call_api(prompt)
        return response
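
# Example (sketch; `ProductSchema` is a hypothetical Pydantic model, and `nvidia`
# is the client from the example above):
#
#   template = "Extract the following information: {content} based on schema: {schema}"
#   extractor = AIExtractor(nvidia, template)
#   raw_json = extractor.extract("<html>...</html>", ProductSchema)
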
class LLMClassifierExtractor(AIExtractor):
    """
    Extractor that uses an LLM to classify and extract structured information from text content.
    This class is designed to handle classification tasks where the LLM generates structured output
    based on a provided schema.
    """

    def __init__(self, llm_client: LLMClient, prompt_template: str, classifier_prompt: str):
        """
        Initializes the LLMClassifierExtractor with an LLM client and prompt templates.

        Args:
            llm_client (LLMClient): An instance of a class that implements the LLMClient interface.
            prompt_template (str): The template to use for generating extraction prompts for the LLM.
            classifier_prompt (str): The template used to classify each chunk for relevance.
        """
        super().__init__(llm_client, prompt_template)
        self.classifier_prompt = classifier_prompt

    def chunk_content(self, content: str, max_tokens: int = 500, is_clean: bool = True) -> List[str]:
        """
        Splits the content into manageable chunks for processing.

        Args:
            content (str): The raw content to be chunked.
            max_tokens (int): Maximum number of tokens per chunk.
            is_clean (bool): Whether the HTML is already cleaned.

        Returns:
            List[str]: A list of text chunks.
        """
        # Use the get_html_chunks function to split the content into chunks
        return get_html_chunks(html=content, max_tokens=max_tokens, is_clean_html=is_clean, attr_cutoff_len=5)

    def classify_chunks(self, chunks: List[str], schema: BaseModel) -> List[Dict[str, Any]]:
        """
        Classifies each chunk using the LLM based on the provided schema.

        Args:
            chunks (List[str]): A list of text chunks to classify.
            schema (BaseModel): A Pydantic model defining the structure of the expected output.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing classified information.
        """
        prompts = [self.classifier_prompt.format(content=chunk, schema=schema.model_json_schema()) for chunk in chunks]
        classified_chunks = []
        responses = self.llm_client.call_batch(prompts)
        for response in responses:
            # Extract the JSON object from the response.
            json_data = extract_markdown_json(response)
            if json_data:
                classified_chunks.append(json_data)
            else:
                # If the classifier output is unparsable, keep the chunk (treat it as relevant).
                classified_chunks.append({
                    "error": "Failed to extract JSON from response",
                    "relevant": 1,
                })
        return classified_chunks

    def extract(self, content: str, schema: BaseModel) -> str:
        """
        Extracts structured information from the given content based on the provided schema.

        Args:
            content (str): The raw content to extract information from.
            schema (BaseModel): A Pydantic model defining the structure of the expected output.

        Returns:
            str: The structured JSON object as a string.
        """
        # Chunk the HTML.
        chunks = self.chunk_content(content, max_tokens=1500)
        print(f"Content successfully chunked into {len(chunks)} pieces.")
        # Classify each chunk using the LLM.
        classified_chunks = self.classify_chunks(chunks, schema)
        print(f"Classified {len(classified_chunks)} chunks.")
        # Keep only the chunks classified as relevant; fall back to all chunks if none qualify.
        positive_chunks = []
        for i, chunk in enumerate(classified_chunks):
            if chunk.get("relevant", 0) > 0:
                positive_chunks.append(chunks[i])
        if len(positive_chunks) == 0:
            positive_chunks = chunks
        filtered_content = "\n\n".join(positive_chunks)
        # Log the first 500 characters of the filtered content.
        print(f"Filtered content for extraction: {filtered_content[:500]}...")
        if not filtered_content:
            print("Warning: No relevant chunks found. Returning empty response.")
            return "{}"
        # Generate the final prompt for extraction.
        prompt = self.prompt_template.format(content=filtered_content, schema=schema.model_json_schema())
        print(f"Generated prompt for extraction: {prompt[:500]}...")
        # Call the LLM to extract structured information.
        llm_response = self.llm_client.call_api(prompt)
        print(f"LLM response: {llm_response[:500]}...")
        if not llm_response:
            print("Warning: LLM response is empty. Returning empty response.")
            return "{}"
        # json_response = extract_markdown_json(llm_response)
        # if json_response is None:
        #     print("Warning: Failed to extract JSON from LLM response. Returning empty response.")
        #     return "{}"
        return llm_response
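
# Example (sketch; `classifier_prompt`, `template`, and `ProductSchema` are hypothetical;
# the classifier prompt should ask for a ```json ...``` block with a "relevant" field):
#
#   extractor = LLMClassifierExtractor(nvidia, template, classifier_prompt)
#   result = extractor.extract(html_page, ProductSchema)
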
class RAGExtractor(AIExtractor):
    """
    RAG-enhanced extractor that uses similarity search to find relevant chunks
    before performing extraction, utilizing HTML header-based chunking and SentenceTransformer embeddings.
    """

    def __init__(self,
                 llm_client: LLMClient,
                 prompt_template: str,
                 embedding_model_path: str = "sentence-transformers/all-mpnet-base-v2",
                 top_k: int = 3):
        """
        Initialize RAG extractor with embedding and chunking capabilities.

        Args:
            llm_client: LLM client for generation.
            prompt_template: Template for prompts.
            embedding_model_path: Path/name for the SentenceTransformer embedding model.
            top_k: Number of top similar chunks to retrieve.
        """
        super().__init__(llm_client, prompt_template)
        self.embedding_model_path = embedding_model_path
        # Initialize the SentenceTransformer model for embeddings.
        self.embedding_model_instance = SentenceTransformer(self.embedding_model_path)
        self.top_k = top_k

    @staticmethod
    def _langchain_HHTS(text: str) -> List[str]:
        """
        Chunks HTML text using Langchain's HTMLHeaderTextSplitter based on h1 and h2 headers.

        Args:
            text (str): The HTML content to chunk.

        Returns:
            List[str]: A list of chunked text strings (extracted from Document objects' page_content).
        """
        headers_to_split_on = [
            ("h1", "Header 1"),
            ("h2", "Header 2"),
            # ("h3", "Header 3"),  # Intentionally not splitting on h3.
        ]
        html_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
        return [doc.page_content for doc in html_splitter.split_text(text)]

    def embed_text(self, text: str) -> Optional[np.ndarray]:
        """
        Generate embeddings for text using the initialized SentenceTransformer model.

        Args:
            text: The text string to embed.

        Returns:
            Optional[np.ndarray]: The embedding vector for the input text, or None if embedding fails.
        """
        try:
            return self.embedding_model_instance.encode(text)
        except Exception as e:
            print(f"Warning: Embedding failed for text: '{text[:50]}...': {e}")
            return None

    def search_similar_chunks(self,
                              query: str,
                              chunks: List[str],
                              embeddings: np.ndarray) -> List[str]:
        """
        Find the most similar chunks to the query within the given list of chunks
        by calculating cosine similarity between their embeddings.

        Args:
            query (str): The query text whose embedding will be used for similarity comparison.
            chunks (List[str]): A list of text chunks to search within.
            embeddings (np.ndarray): Precomputed embeddings for the chunks, corresponding to the 'chunks' list.

        Returns:
            List[str]: A list of the 'top_k' most similar chunks to the query.
        """
        query_embedding = self.embed_text(query)
        similarities = []
        if query_embedding.ndim > 1:
            query_embedding = query_embedding.flatten()
        for i, chunk_embedding in enumerate(embeddings):
            if chunk_embedding.ndim > 1:
                chunk_embedding = chunk_embedding.flatten()
            norm_query = np.linalg.norm(query_embedding)
            norm_chunk = np.linalg.norm(chunk_embedding)
            if norm_query == 0 or norm_chunk == 0:
                similarity = 0.0
            else:
                # Cosine similarity between the query and chunk embeddings.
                similarity = np.dot(query_embedding, chunk_embedding) / (norm_query * norm_chunk)
            similarities.append((similarity, i))
        # Rank chunks by similarity (highest first) and keep the top_k.
        similarities.sort(key=lambda x: x[0], reverse=True)
        top_indices = [idx for _, idx in similarities[:self.top_k]]
        return [chunks[i] for i in top_indices]

    def extract(self, content: str, schema: BaseModel, query: str = None) -> str:
        """
        Overrides the base AIExtractor's method to implement RAG-enhanced extraction.
        This function first chunks the input HTML content, then uses a query to find
        the most relevant chunks via embedding similarity, and finally sends these
        relevant chunks as context to the LLM for structured information extraction.

        Args:
            content (str): The raw HTML content from which to extract information.
            schema (BaseModel): A Pydantic model defining the desired output structure for the LLM.
            query (str, optional): An optional query string to guide the retrieval of relevant chunks.
                If not provided, a default query based on the schema will be used.

        Returns:
            str: The structured JSON object as a string, as generated by the LLM.
        """
        start_time = time.time()
        if not query:
            query = f"Extract information based on the following JSON schema: {schema.model_json_schema()}"
            print(f"No explicit query provided for retrieval. Using default: '{query[:100]}...'")
        chunks = self._langchain_HHTS(content)
        print(f"Content successfully chunked into {len(chunks)} pieces.")
        combined_content_for_llm = ""
        if not chunks:
            print("Warning: No chunks were generated from the provided content. The entire original content will be sent to the LLM.")
            combined_content_for_llm = content
        else:
            chunk_embeddings = np.array([self.embed_text(chunk) for chunk in chunks])
            print(f"Generated embeddings for {len(chunks)} chunks.")
            similar_chunks = self.search_similar_chunks(query, chunks, chunk_embeddings)
            print(f"Retrieved {len(similar_chunks)} similar chunks based on the query.")
            combined_content_for_llm = "\n\n".join(similar_chunks)
            print(f"Combined content for LLM (truncated): '{combined_content_for_llm[:200]}...'")
        prompt = self.prompt_template.format(content=combined_content_for_llm, schema=schema.model_json_schema())
        print(f"Sending prompt to LLM (truncated): '{prompt[:500]}...'")
        llm_response = self.llm_client.call_api(prompt)
        execution_time = (time.time() - start_time) * 1000
        print(f"Extraction process completed in {execution_time:.2f} milliseconds.")
        print(f"LLM's final response: {llm_response}")
        print("=" * 78)
        return llm_response
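
# End-to-end example (sketch; `template`, `ProductSchema`, and `html_page` are
# hypothetical, and `nvidia` is the client from the earlier example):
#
#   rag = RAGExtractor(nvidia, template, top_k=3)
#   structured = rag.extract(html_page, ProductSchema, query="product name and price")
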