veerukhannan commited on
Commit
cb15139
·
verified ·
1 Parent(s): 3120cc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -15
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  from typing import List, Dict, Tuple
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
5
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
6
  import torch
7
  import os
8
  from astrapy.db import AstraDB
@@ -11,6 +11,7 @@ from huggingface_hub import login
11
  import time
12
  import logging
13
  from functools import lru_cache
 
14
 
15
  # Configure logging
16
  logging.basicConfig(
@@ -23,25 +24,26 @@ logger = logging.getLogger(__name__)
23
  load_dotenv()
24
  login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
25
 
26
- # Initialize model with CPU-compatible settings
27
- model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
28
- model = AutoModelForCausalLM.from_pretrained(
29
- model_name,
30
- device_map="auto",
31
- torch_dtype=torch.float32, # Use float32 for CPU compatibility
32
- )
33
- tokenizer = AutoTokenizer.from_pretrained(model_name)
34
-
35
  class LegalTextSearchBot:
36
  def __init__(self):
37
  try:
 
38
  self.astra_db = AstraDB(
39
  token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
40
  api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
41
  )
42
- self.collection = self.astra_db.collection("legal_content")
43
 
44
- # Initialize pipeline with CPU settings
 
 
 
 
 
 
 
 
 
45
  pipe = pipeline(
46
  "text-generation",
47
  model=model,
@@ -54,6 +56,14 @@ class LegalTextSearchBot:
54
  )
55
  self.llm = HuggingFacePipeline(pipeline=pipe)
56
 
 
 
 
 
 
 
 
 
57
  self.template = """
58
  IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
59
 
@@ -81,17 +91,45 @@ class LegalTextSearchBot:
81
  logger.error(f"Error initializing LegalTextSearchBot: {str(e)}")
82
  raise
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @lru_cache(maxsize=100)
85
  def _cached_search(self, query: str) -> tuple:
86
- """Cached version of vector search to improve performance"""
87
  try:
 
 
 
88
  results = list(self.collection.vector_find(
89
- query,
90
  limit=5,
91
  fields=["section_number", "title", "chapter_number", "chapter_title",
92
  "content", "type", "metadata"]
93
  ))
94
- return tuple(results) # Convert to tuple for caching
95
  except Exception as e:
96
  logger.error(f"Error in vector search: {str(e)}")
97
  return tuple()
 
2
  from typing import List, Dict, Tuple
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
5
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel
6
  import torch
7
  import os
8
  from astrapy.db import AstraDB
 
11
  import time
12
  import logging
13
  from functools import lru_cache
14
+ import numpy as np
15
 
16
  # Configure logging
17
  logging.basicConfig(
 
24
  load_dotenv()
25
  login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
26
 
 
 
 
 
 
 
 
 
 
27
  class LegalTextSearchBot:
28
  def __init__(self):
29
  try:
30
+ # Initialize AstraDB connection
31
  self.astra_db = AstraDB(
32
  token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
33
  api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
34
  )
35
+ self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))
36
 
37
+ # Initialize language model for text generation
38
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
39
+ model = AutoModelForCausalLM.from_pretrained(
40
+ model_name,
41
+ device_map="auto",
42
+ torch_dtype=torch.float32,
43
+ )
44
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
45
+
46
+ # Initialize text generation pipeline
47
  pipe = pipeline(
48
  "text-generation",
49
  model=model,
 
56
  )
57
  self.llm = HuggingFacePipeline(pipeline=pipe)
58
 
59
+ # Initialize embedding model
60
+ self.embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
61
+ self.embedding_pipeline = pipeline(
62
+ "feature-extraction",
63
+ model=self.embedding_model_name,
64
+ device_map="auto"
65
+ )
66
+
67
  self.template = """
68
  IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
69
 
 
91
  logger.error(f"Error initializing LegalTextSearchBot: {str(e)}")
92
  raise
93
 
94
+ def get_embedding(self, text: str) -> List[float]:
95
+ """Generate embedding vector for text"""
96
+ try:
97
+ # Clean and prepare text
98
+ text = text.replace('\n', ' ').strip()
99
+
100
+ # Generate embedding
101
+ outputs = self.embedding_pipeline(text)
102
+ embeddings = torch.mean(torch.tensor(outputs[0]), dim=0)
103
+
104
+ # Convert to list and ensure correct dimension
105
+ embedding_list = embeddings.tolist()
106
+
107
+ # Pad or truncate to exactly 1024 dimensions
108
+ if len(embedding_list) < 1024:
109
+ embedding_list.extend([0.0] * (1024 - len(embedding_list)))
110
+ elif len(embedding_list) > 1024:
111
+ embedding_list = embedding_list[:1024]
112
+
113
+ return embedding_list
114
+
115
+ except Exception as e:
116
+ logger.error(f"Error generating embedding: {str(e)}")
117
+ raise
118
+
119
  @lru_cache(maxsize=100)
120
  def _cached_search(self, query: str) -> tuple:
121
+ """Cached version of vector search"""
122
  try:
123
+ # Generate embedding for query
124
+ query_embedding = self.get_embedding(query)
125
+
126
  results = list(self.collection.vector_find(
127
+ query_embedding,
128
  limit=5,
129
  fields=["section_number", "title", "chapter_number", "chapter_title",
130
  "content", "type", "metadata"]
131
  ))
132
+ return tuple(results)
133
  except Exception as e:
134
  logger.error(f"Error in vector search: {str(e)}")
135
  return tuple()