Solved FAISS index issue
Browse files- app/main.py +12 -11
app/main.py
CHANGED
@@ -414,19 +414,17 @@ logger = logging.getLogger(__name__)
|
|
414 |
# Initialize global variables in app state
|
415 |
@app.on_event("startup")
|
416 |
async def startup_event():
|
417 |
-
|
418 |
"""Initialize the application on startup."""
|
419 |
logger = logging.getLogger(__name__)
|
420 |
logger.info("Starting application initialization...")
|
421 |
|
422 |
-
|
423 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
424 |
logger.info(f"Using device: {device}")
|
425 |
|
426 |
if device == "cpu":
|
427 |
logger.warning("GPU not detected. Model will run slower on CPU.")
|
428 |
-
|
429 |
-
|
430 |
# Set NLTK data path
|
431 |
nltk_data_dir = os.environ.get('NLTK_DATA', os.path.join(os.path.expanduser('~'), 'nltk_data'))
|
432 |
os.makedirs(nltk_data_dir, exist_ok=True)
|
@@ -442,8 +440,7 @@ async def startup_event():
|
|
442 |
|
443 |
# Initialize the model and index
|
444 |
try:
|
445 |
-
|
446 |
-
model = pipeline(
|
447 |
"text-generation",
|
448 |
model=MODEL_ID,
|
449 |
trust_remote_code=True,
|
@@ -451,10 +448,14 @@ async def startup_event():
|
|
451 |
device_map="auto",
|
452 |
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
453 |
)
|
454 |
-
embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
455 |
|
456 |
-
# Load or create the FAISS index
|
457 |
faiss_index, documents, embedding_model = await load_or_create_index()
|
|
|
|
|
|
|
|
|
|
|
|
|
458 |
logger.info("Application initialization completed successfully")
|
459 |
except Exception as e:
|
460 |
logger.error(f"Error initializing application: {str(e)}")
|
@@ -490,9 +491,9 @@ async def generate_content(request: ContentRequest):
|
|
490 |
|
491 |
response = generate_response_with_rag(
|
492 |
request.topic, # Use topic as the prompt
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
settings
|
497 |
)
|
498 |
|
|
|
414 |
# Initialize global variables in app state
|
415 |
@app.on_event("startup")
|
416 |
async def startup_event():
|
|
|
417 |
"""Initialize the application on startup."""
|
418 |
logger = logging.getLogger(__name__)
|
419 |
logger.info("Starting application initialization...")
|
420 |
|
421 |
+
# Check if CUDA is available
|
422 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
423 |
logger.info(f"Using device: {device}")
|
424 |
|
425 |
if device == "cpu":
|
426 |
logger.warning("GPU not detected. Model will run slower on CPU.")
|
427 |
+
|
|
|
428 |
# Set NLTK data path
|
429 |
nltk_data_dir = os.environ.get('NLTK_DATA', os.path.join(os.path.expanduser('~'), 'nltk_data'))
|
430 |
os.makedirs(nltk_data_dir, exist_ok=True)
|
|
|
440 |
|
441 |
# Initialize the model and index
|
442 |
try:
|
443 |
+
app.state.pipe = pipeline(
|
|
|
444 |
"text-generation",
|
445 |
model=MODEL_ID,
|
446 |
trust_remote_code=True,
|
|
|
448 |
device_map="auto",
|
449 |
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
450 |
)
|
|
|
451 |
|
|
|
452 |
faiss_index, documents, embedding_model = await load_or_create_index()
|
453 |
+
|
454 |
+
# Store these in app.state for access across the application
|
455 |
+
app.state.faiss_index = faiss_index
|
456 |
+
app.state.documents = documents
|
457 |
+
app.state.embedding_model = embedding_model
|
458 |
+
|
459 |
logger.info("Application initialization completed successfully")
|
460 |
except Exception as e:
|
461 |
logger.error(f"Error initializing application: {str(e)}")
|
|
|
491 |
|
492 |
response = generate_response_with_rag(
|
493 |
request.topic, # Use topic as the prompt
|
494 |
+
app.state.faiss_index,
|
495 |
+
app.state.embedding_model,
|
496 |
+
app.state.documents,
|
497 |
settings
|
498 |
)
|
499 |
|