vhr1007 committed on
Commit
7d3c394
·
1 Parent(s): 658ace0
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -3,7 +3,6 @@ from fastapi import FastAPI, Depends, HTTPException
3
  import logging
4
  from pydantic import BaseModel
5
  from transformers import AutoTokenizer, AutoModel
6
- from sentence_transformers import models, SentenceTransformer
7
  from services.qdrant_searcher import QdrantSearcher
8
  from services.openai_service import generate_rag_response
9
  from utils.auth import token_required
@@ -46,7 +45,7 @@ access_token = os.getenv('QDRANT_ACCESS_TOKEN')
46
  if not qdrant_url or not access_token:
47
  raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")
48
 
49
- # Initialize the SentenceTransformer model with trust_remote_code using transformers
50
  try:
51
  cache_folder = os.path.join(hf_home_dir, "transformers_cache")
52
 
@@ -54,18 +53,17 @@ try:
54
  tokenizer = AutoTokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
55
  model = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
56
 
57
- # Wrap the model into a SentenceTransformer
58
- word_embedding_model = models.Transformer(model_name_or_path='nomic-ai/nomic-embed-text-v1.5', model=model, tokenizer=tokenizer)
59
- pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
60
- encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
61
-
62
- logging.info("Successfully loaded the SentenceTransformer model.")
63
  except Exception as e:
64
- logging.error(f"Failed to load the SentenceTransformer model: {e}")
65
- raise HTTPException(status_code=500, detail="Failed to load the SentenceTransformer model.")
66
 
67
- # Initialize the Qdrant searcher
68
- searcher = QdrantSearcher(encoder, qdrant_url, access_token)
 
 
 
 
69
 
70
  # Define the request body models
71
  class SearchDocumentsRequest(BaseModel):
@@ -120,6 +118,10 @@ async def generate_rag_response_api(
120
  logging.error(f"Search documents error: {error}")
121
  raise HTTPException(status_code=500, detail=error)
122
 
 
 
 
 
123
  response, error = generate_rag_response(hits, body.search_query)
124
 
125
  if error:
 
3
  import logging
4
  from pydantic import BaseModel
5
  from transformers import AutoTokenizer, AutoModel
 
6
  from services.qdrant_searcher import QdrantSearcher
7
  from services.openai_service import generate_rag_response
8
  from utils.auth import token_required
 
45
  if not qdrant_url or not access_token:
46
  raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")
47
 
48
+ # Load the model and tokenizer with trust_remote_code=True
49
  try:
50
  cache_folder = os.path.join(hf_home_dir, "transformers_cache")
51
 
 
53
  tokenizer = AutoTokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
54
  model = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
55
 
56
+ logging.info("Successfully loaded the model and tokenizer with transformers.")
 
 
 
 
 
57
  except Exception as e:
58
+ logging.error(f"Failed to load the model: {e}")
59
+ raise HTTPException(status_code=500, detail="Failed to load the custom model.")
60
 
61
+ # Function to embed text using the model
62
+ def embed_texts(texts):
63
+ inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
64
+ outputs = model(**inputs)
65
+ embeddings = outputs.last_hidden_state.mean(dim=1) # Example: plain mean pooling over dim 1; the attention mask is not applied, so padded positions are averaged in
66
+ return embeddings
67
 
68
  # Define the request body models
69
  class SearchDocumentsRequest(BaseModel):
 
118
  logging.error(f"Search documents error: {error}")
119
  raise HTTPException(status_code=500, detail=error)
120
 
121
+ # Example: Use custom embedding logic
122
+ # embeddings = embed_texts([hit['text'] for hit in hits])
123
+ # Use embeddings for further processing...
124
+
125
  response, error = generate_rag_response(hits, body.search_query)
126
 
127
  if error: