new_ads_obnova / func_ai.py
Uniaff's picture
Update func_ai.py
4a36ad2 verified
raw
history blame
3.67 kB
# func_ai.py
import requests
import torch
from transformers import pipeline
from deep_translator import GoogleTranslator
import time
import os
from datetime import datetime
VECTOR_API_URL = os.getenv('API_URL')
def log_message(message):
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"[{timestamp}] {message}")
# Инициализация моделей
def init_models():
log_message("Инициализация моделей AI.")
sentiment_model = pipeline(
'sentiment-analysis',
model='cardiffnlp/twitter-xlm-roberta-base-sentiment',
tokenizer='cardiffnlp/twitter-xlm-roberta-base-sentiment',
device=0 if torch.cuda.is_available() else -1
)
classifier = pipeline(
"zero-shot-classification",
model="valhalla/distilbart-mnli-12-6",
device=0 if torch.cuda.is_available() else -1
)
return sentiment_model, classifier
sentiment_model, classifier = init_models()
def classify_comment(text):
if not text:
log_message("Получен пустой текст для классификации.")
return "non-interrogative"
log_message(f"Классификация комментария: {text}")
try:
translated_text = GoogleTranslator(source='auto', target="en").translate(text)
log_message(f"Переведенный текст: {translated_text}")
except Exception as e:
log_message(f"Ошибка при переводе: {e}")
return "non-interrogative"
if not translated_text:
log_message("Перевод вернул пустой текст.")
return "non-interrogative"
try:
result = classifier(translated_text, ["interrogative", "non-interrogative"], clean_up_tokenization_spaces=True)
log_message(f"Результат классификации: {result}")
except Exception as e:
log_message(f"Ошибка при классификации: {e}")
return "non-interrogative"
top_class = result['labels'][0]
log_message(f"Верхний класс: {top_class}")
return top_class
def analyze_sentiment(comments):
log_message("Начинаем анализ настроений.")
results = []
for i in range(0, len(comments), 50):
batch = comments[i:i + 50]
log_message(f"Анализируем батч с {i} по {i + len(batch)} комментарий: {batch}")
try:
batch_results = sentiment_model(batch)
log_message(f"Результаты батча: {batch_results}")
results.extend(batch_results)
except Exception as e:
log_message(f"Ошибка при анализе настроений батча: {e}")
time.sleep(1) # Задержка для предотвращения перегрузки
log_message("Анализ настроений завершен.")
return results
def retrieve_from_vdb(query):
log_message(f"Отправка запроса к FastAPI сервису: {query}")
try:
response = requests.post(f"{VECTOR_API_URL}/search/", json={"query": query})
if response.status_code == 200:
results = response.json().get("results", [])
log_message(f"Получено {len(results)} результатов: {results}")
return results
else:
log_message(f"Ошибка при поиске: {response.text}")
return []
except Exception as e:
log_message(f"Ошибка при запросе к векторной базе данных: {e}")
return []