Spaces:
Sleeping
Sleeping
Commit
·
d8c6d94
1
Parent(s):
92d8c87
update cache
Browse files- src/demo/asg_retriever.py +25 -5
- src/demo/category_and_tsne.py +17 -4
- src/demo/main.py +9 -1
- src/demo/path_utils.py +15 -0
- src/demo/survey_generation_pipeline/asg_retriever.py +72 -42
- src/demo/survey_generation_pipeline/category_and_tsne.py +25 -7
- src/demo/survey_generation_pipeline/main.py +17 -9
- src/demo/survey_generator_api.py +75 -124
- src/demo/views.py +13 -3
src/demo/asg_retriever.py
CHANGED
@@ -8,7 +8,11 @@ from .asg_splitter import TextSplitting
|
|
8 |
from langchain_huggingface import HuggingFaceEmbeddings
|
9 |
import time
|
10 |
import concurrent.futures
|
11 |
-
from .path_utils import get_path
|
|
|
|
|
|
|
|
|
12 |
|
13 |
class Retriever:
|
14 |
client = None
|
@@ -201,7 +205,11 @@ def process_pdf(file_path: str, survey_id: str, embedder: HuggingFaceEmbeddings,
|
|
201 |
return collection_name, embeddings_list, documents_list, metadata_list,title_new
|
202 |
|
203 |
def query_embeddings(collection_name: str, query_list: list):
|
204 |
-
|
|
|
|
|
|
|
|
|
205 |
retriever = Retriever()
|
206 |
|
207 |
final_context = ""
|
@@ -222,7 +230,11 @@ def query_embeddings(collection_name: str, query_list: list):
|
|
222 |
|
223 |
# new, may be in parallel
|
224 |
def query_embeddings_new(collection_name: str, query_list: list):
|
225 |
-
|
|
|
|
|
|
|
|
|
226 |
retriever = Retriever()
|
227 |
|
228 |
final_context = ""
|
@@ -250,7 +262,11 @@ def query_embeddings_new(collection_name: str, query_list: list):
|
|
250 |
|
251 |
# wza
|
252 |
def query_embeddings_new_new(collection_name: str, query_list: list):
|
253 |
-
|
|
|
|
|
|
|
|
|
254 |
retriever = Retriever()
|
255 |
|
256 |
final_context = "" # Stores concatenated context
|
@@ -313,7 +329,11 @@ def query_multiple_collections(collection_names: list[str], query_list: list[str
|
|
313 |
dict: Combined results from all collections, grouped by collection.
|
314 |
"""
|
315 |
# Define embedder inside the function
|
316 |
-
|
|
|
|
|
|
|
|
|
317 |
retriever = Retriever()
|
318 |
|
319 |
def query_single_collection(collection_name: str):
|
|
|
8 |
from langchain_huggingface import HuggingFaceEmbeddings
|
9 |
import time
|
10 |
import concurrent.futures
|
11 |
+
from .path_utils import get_path, setup_hf_cache
|
12 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
13 |
+
|
14 |
+
# 设置 Hugging Face 缓存目录
|
15 |
+
cache_dir = setup_hf_cache()
|
16 |
|
17 |
class Retriever:
|
18 |
client = None
|
|
|
205 |
return collection_name, embeddings_list, documents_list, metadata_list,title_new
|
206 |
|
207 |
def query_embeddings(collection_name: str, query_list: list):
|
208 |
+
try:
|
209 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
210 |
+
except Exception as e:
|
211 |
+
print(f"Error initializing embedder: {e}")
|
212 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
213 |
retriever = Retriever()
|
214 |
|
215 |
final_context = ""
|
|
|
230 |
|
231 |
# new, may be in parallel
|
232 |
def query_embeddings_new(collection_name: str, query_list: list):
|
233 |
+
try:
|
234 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
235 |
+
except Exception as e:
|
236 |
+
print(f"Error initializing embedder: {e}")
|
237 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
238 |
retriever = Retriever()
|
239 |
|
240 |
final_context = ""
|
|
|
262 |
|
263 |
# wza
|
264 |
def query_embeddings_new_new(collection_name: str, query_list: list):
|
265 |
+
try:
|
266 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
267 |
+
except Exception as e:
|
268 |
+
print(f"Error initializing embedder: {e}")
|
269 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
270 |
retriever = Retriever()
|
271 |
|
272 |
final_context = "" # Stores concatenated context
|
|
|
329 |
dict: Combined results from all collections, grouped by collection.
|
330 |
"""
|
331 |
# Define embedder inside the function
|
332 |
+
try:
|
333 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
334 |
+
except Exception as e:
|
335 |
+
print(f"Error initializing embedder: {e}")
|
336 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
337 |
retriever = Retriever()
|
338 |
|
339 |
def query_single_collection(collection_name: str):
|
src/demo/category_and_tsne.py
CHANGED
@@ -7,6 +7,8 @@ import seaborn as sns
|
|
7 |
import json
|
8 |
from sklearn.manifold import TSNE
|
9 |
from sklearn.cluster import AgglomerativeClustering
|
|
|
|
|
10 |
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
from bertopic import BERTopic
|
@@ -14,7 +16,7 @@ from bertopic.representation import KeyBERTInspired
|
|
14 |
from sklearn.feature_extraction.text import CountVectorizer
|
15 |
from bertopic.vectorizers import ClassTfidfTransformer
|
16 |
from umap import UMAP
|
17 |
-
from .path_utils import get_path
|
18 |
|
19 |
plt.switch_backend('agg')
|
20 |
device = 0
|
@@ -35,6 +37,9 @@ import matplotlib.pyplot as plt
|
|
35 |
from sklearn.manifold import TSNE
|
36 |
import seaborn as sns
|
37 |
|
|
|
|
|
|
|
38 |
class DimensionalityReduction:
|
39 |
def fit(self, X):
|
40 |
return self
|
@@ -44,7 +49,11 @@ class DimensionalityReduction:
|
|
44 |
|
45 |
class ClusteringWithTopic:
|
46 |
def __init__(self, df, n_topics=3):
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
# umap_model = DimensionalityReduction()
|
49 |
umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', init = 'pca')
|
50 |
hdbscan_model = AgglomerativeClustering(n_clusters=n_topics)
|
@@ -81,7 +90,11 @@ class ClusteringWithTopic:
|
|
81 |
初始化 ClusteringWithTopic,接受一个 n_topics_list,其中包含多个聚类数目,
|
82 |
选取 silhouette_score 最高的结果。
|
83 |
"""
|
84 |
-
|
|
|
|
|
|
|
|
|
85 |
self.embeddings = embedding_model.encode(df, show_progress_bar=True)
|
86 |
|
87 |
self.df = df
|
@@ -97,7 +110,7 @@ class ClusteringWithTopic:
|
|
97 |
# 用于存储不同聚类数目的结果
|
98 |
self.best_n_topics = None
|
99 |
self.best_labels = None
|
100 |
-
self.best_score = -1
|
101 |
# def fit_and_get_labels(self, X):
|
102 |
# topics, probs = self.topic_model.fit_transform(self.df, self.embeddings)
|
103 |
# return topics
|
|
|
7 |
import json
|
8 |
from sklearn.manifold import TSNE
|
9 |
from sklearn.cluster import AgglomerativeClustering
|
10 |
+
import os
|
11 |
+
import tempfile
|
12 |
|
13 |
from sentence_transformers import SentenceTransformer
|
14 |
from bertopic import BERTopic
|
|
|
16 |
from sklearn.feature_extraction.text import CountVectorizer
|
17 |
from bertopic.vectorizers import ClassTfidfTransformer
|
18 |
from umap import UMAP
|
19 |
+
from .path_utils import get_path, setup_hf_cache
|
20 |
|
21 |
plt.switch_backend('agg')
|
22 |
device = 0
|
|
|
37 |
from sklearn.manifold import TSNE
|
38 |
import seaborn as sns
|
39 |
|
40 |
+
# 设置 Hugging Face 缓存目录
|
41 |
+
cache_dir = setup_hf_cache()
|
42 |
+
|
43 |
class DimensionalityReduction:
|
44 |
def fit(self, X):
|
45 |
return self
|
|
|
49 |
|
50 |
class ClusteringWithTopic:
|
51 |
def __init__(self, df, n_topics=3):
|
52 |
+
try:
|
53 |
+
embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, cache_folder=cache_dir)
|
54 |
+
except Exception as e:
|
55 |
+
print(f"Error initializing SentenceTransformer: {e}")
|
56 |
+
embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
|
57 |
# umap_model = DimensionalityReduction()
|
58 |
umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', init = 'pca')
|
59 |
hdbscan_model = AgglomerativeClustering(n_clusters=n_topics)
|
|
|
90 |
初始化 ClusteringWithTopic,接受一个 n_topics_list,其中包含多个聚类数目,
|
91 |
选取 silhouette_score 最高的结果。
|
92 |
"""
|
93 |
+
try:
|
94 |
+
embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, cache_folder=cache_dir)
|
95 |
+
except Exception as e:
|
96 |
+
print(f"Error initializing SentenceTransformer: {e}")
|
97 |
+
embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
|
98 |
self.embeddings = embedding_model.encode(df, show_progress_bar=True)
|
99 |
|
100 |
self.df = df
|
|
|
110 |
# 用于存储不同聚类数目的结果
|
111 |
self.best_n_topics = None
|
112 |
self.best_labels = None
|
113 |
+
self.best_score = -1
|
114 |
# def fit_and_get_labels(self, X):
|
115 |
# topics, probs = self.topic_model.fit_transform(self.df, self.embeddings)
|
116 |
# return topics
|
src/demo/main.py
CHANGED
@@ -20,6 +20,10 @@ from asg_outline import OutlineGenerator, generateSurvey_qwen_new
|
|
20 |
import os
|
21 |
from markdown_pdf import MarkdownPdf, Section # Assuming you are using markdown_pdf
|
22 |
from typing import Any
|
|
|
|
|
|
|
|
|
23 |
|
24 |
def clean_str(input_str):
|
25 |
input_str = str(input_str).strip().lower()
|
@@ -135,7 +139,11 @@ class ASG_system:
|
|
135 |
|
136 |
|
137 |
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
138 |
-
|
|
|
|
|
|
|
|
|
139 |
self.pipeline = transformers.pipeline(
|
140 |
"text-generation",
|
141 |
model=model_id,
|
|
|
20 |
import os
|
21 |
from markdown_pdf import MarkdownPdf, Section # Assuming you are using markdown_pdf
|
22 |
from typing import Any
|
23 |
+
from .path_utils import get_path, setup_hf_cache
|
24 |
+
|
25 |
+
# 设置 Hugging Face 缓存目录
|
26 |
+
cache_dir = setup_hf_cache()
|
27 |
|
28 |
def clean_str(input_str):
|
29 |
input_str = str(input_str).strip().lower()
|
|
|
139 |
|
140 |
|
141 |
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
142 |
+
try:
|
143 |
+
self.embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
144 |
+
except Exception as e:
|
145 |
+
print(f"Error initializing embedder: {e}")
|
146 |
+
self.embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
147 |
self.pipeline = transformers.pipeline(
|
148 |
"text-generation",
|
149 |
model=model_id,
|
src/demo/path_utils.py
CHANGED
@@ -1,6 +1,21 @@
|
|
1 |
import os
|
2 |
import tempfile
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
# 检查是否在 Hugging Face Spaces 环境中
|
5 |
def get_data_paths():
|
6 |
# 如果在 Hugging Face Spaces 中,使用临时目录
|
|
|
1 |
import os
|
2 |
import tempfile
|
3 |
|
4 |
+
# 设置 Hugging Face 缓存目录
|
5 |
+
def setup_hf_cache():
|
6 |
+
"""设置 Hugging Face 缓存目录,在 Hugging Face Spaces 中使用临时目录"""
|
7 |
+
if os.environ.get('SPACE_ID') or os.environ.get('HF_SPACE_ID'):
|
8 |
+
# 在 Hugging Face Spaces 中使用临时目录作为缓存
|
9 |
+
cache_dir = tempfile.mkdtemp()
|
10 |
+
os.environ['HF_HOME'] = cache_dir
|
11 |
+
os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_dir, 'transformers')
|
12 |
+
os.environ['HF_HUB_CACHE'] = os.path.join(cache_dir, 'hub')
|
13 |
+
print(f"Using Hugging Face cache directory: {cache_dir}")
|
14 |
+
return cache_dir
|
15 |
+
else:
|
16 |
+
# 本地环境使用默认缓存目录
|
17 |
+
return None
|
18 |
+
|
19 |
# 检查是否在 Hugging Face Spaces 环境中
|
20 |
def get_data_paths():
|
21 |
# 如果在 Hugging Face Spaces 中,使用临时目录
|
src/demo/survey_generation_pipeline/asg_retriever.py
CHANGED
@@ -8,7 +8,11 @@ from .asg_splitter import TextSplitting
|
|
8 |
from langchain_huggingface import HuggingFaceEmbeddings
|
9 |
import time
|
10 |
import concurrent.futures
|
11 |
-
from ..path_utils import get_path
|
|
|
|
|
|
|
|
|
12 |
|
13 |
class Retriever:
|
14 |
client = None
|
@@ -223,7 +227,11 @@ def process_pdf(file_path: str, survey_id: str, embedder: HuggingFaceEmbeddings,
|
|
223 |
return collection_name, embeddings_list, documents_list, metadata_list,title_new
|
224 |
|
225 |
def query_embeddings(collection_name: str, query_list: list):
|
226 |
-
|
|
|
|
|
|
|
|
|
227 |
retriever = Retriever()
|
228 |
|
229 |
final_context = ""
|
@@ -244,7 +252,11 @@ def query_embeddings(collection_name: str, query_list: list):
|
|
244 |
|
245 |
# new, may be in parallel
|
246 |
def query_embeddings_new(collection_name: str, query_list: list):
|
247 |
-
|
|
|
|
|
|
|
|
|
248 |
retriever = Retriever()
|
249 |
|
250 |
final_context = ""
|
@@ -270,45 +282,59 @@ def query_embeddings_new(collection_name: str, query_list: list):
|
|
270 |
seen_chunks.add(chunk)
|
271 |
return final_context
|
272 |
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
|
|
|
|
|
|
278 |
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
|
313 |
return final_context, citation_data_list
|
314 |
|
@@ -325,7 +351,11 @@ def query_multiple_collections(collection_names: list[str], query_list: list[str
|
|
325 |
dict: Combined results from all collections, grouped by collection.
|
326 |
"""
|
327 |
# Define embedder inside the function
|
328 |
-
|
|
|
|
|
|
|
|
|
329 |
retriever = Retriever()
|
330 |
|
331 |
def query_single_collection(collection_name: str):
|
|
|
8 |
from langchain_huggingface import HuggingFaceEmbeddings
|
9 |
import time
|
10 |
import concurrent.futures
|
11 |
+
from ..path_utils import get_path, setup_hf_cache
|
12 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
13 |
+
|
14 |
+
# 设置 Hugging Face 缓存目录
|
15 |
+
cache_dir = setup_hf_cache()
|
16 |
|
17 |
class Retriever:
|
18 |
client = None
|
|
|
227 |
return collection_name, embeddings_list, documents_list, metadata_list,title_new
|
228 |
|
229 |
def query_embeddings(collection_name: str, query_list: list):
|
230 |
+
try:
|
231 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
232 |
+
except Exception as e:
|
233 |
+
print(f"Error initializing embedder: {e}")
|
234 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
235 |
retriever = Retriever()
|
236 |
|
237 |
final_context = ""
|
|
|
252 |
|
253 |
# new, may be in parallel
|
254 |
def query_embeddings_new(collection_name: str, query_list: list):
|
255 |
+
try:
|
256 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
257 |
+
except Exception as e:
|
258 |
+
print(f"Error initializing embedder: {e}")
|
259 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
260 |
retriever = Retriever()
|
261 |
|
262 |
final_context = ""
|
|
|
282 |
seen_chunks.add(chunk)
|
283 |
return final_context
|
284 |
|
285 |
+
# wza
|
286 |
+
def query_embeddings_new_new(collection_name: str, query_list: list):
|
287 |
+
try:
|
288 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
289 |
+
except Exception as e:
|
290 |
+
print(f"Error initializing embedder: {e}")
|
291 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
292 |
+
retriever = Retriever()
|
293 |
|
294 |
+
final_context = "" # Stores concatenated context
|
295 |
+
citation_data_list = [] # Stores chunk content and collection name as source
|
296 |
+
seen_chunks = set() # Ensures unique chunks are added
|
297 |
+
|
298 |
+
def process_query(query_text):
|
299 |
+
# Embed the query text and retrieve relevant chunks
|
300 |
+
query_embeddings = embedder.embed_query(query_text)
|
301 |
+
query_result = retriever.query_chroma(
|
302 |
+
collection_name=collection_name,
|
303 |
+
query_embeddings=[query_embeddings],
|
304 |
+
n_results=5 # Fixed number of results
|
305 |
+
)
|
306 |
+
return query_result
|
307 |
+
|
308 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
309 |
+
future_to_query = {executor.submit(process_query, q): q for q in query_list}
|
310 |
+
for future in concurrent.futures.as_completed(future_to_query):
|
311 |
+
query_text = future_to_query[future]
|
312 |
+
try:
|
313 |
+
query_result = future.result()
|
314 |
+
except Exception as e:
|
315 |
+
print(f"Query '{query_text}' failed with exception: {e}")
|
316 |
+
continue
|
317 |
+
|
318 |
+
if "documents" not in query_result or "distances" not in query_result:
|
319 |
+
continue
|
320 |
+
if not query_result["documents"] or not query_result["distances"]:
|
321 |
+
continue
|
322 |
+
docs_list = query_result["documents"][0] if query_result["documents"] else []
|
323 |
+
dist_list = query_result["distances"][0] if query_result["distances"] else []
|
324 |
+
|
325 |
+
if len(docs_list) != len(dist_list):
|
326 |
+
continue
|
327 |
+
|
328 |
+
for chunk, distance in zip(docs_list, dist_list):
|
329 |
+
processed_chunk = chunk.strip()
|
330 |
+
if processed_chunk not in seen_chunks:
|
331 |
+
final_context += processed_chunk + "//\n"
|
332 |
+
seen_chunks.add(processed_chunk)
|
333 |
+
citation_data_list.append({
|
334 |
+
"source": collection_name,
|
335 |
+
"distance": distance,
|
336 |
+
"content": processed_chunk,
|
337 |
+
})
|
338 |
|
339 |
return final_context, citation_data_list
|
340 |
|
|
|
351 |
dict: Combined results from all collections, grouped by collection.
|
352 |
"""
|
353 |
# Define embedder inside the function
|
354 |
+
try:
|
355 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
356 |
+
except Exception as e:
|
357 |
+
print(f"Error initializing embedder: {e}")
|
358 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
359 |
retriever = Retriever()
|
360 |
|
361 |
def query_single_collection(collection_name: str):
|
src/demo/survey_generation_pipeline/category_and_tsne.py
CHANGED
@@ -1,15 +1,22 @@
|
|
1 |
from sklearn.metrics import silhouette_score
|
2 |
|
3 |
import numpy as np
|
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
import seaborn as sns
|
6 |
-
import
|
7 |
from sklearn.manifold import TSNE
|
8 |
from sklearn.cluster import AgglomerativeClustering
|
9 |
-
import
|
10 |
-
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
plt.switch_backend('agg')
|
15 |
device = 0
|
@@ -30,6 +37,9 @@ import matplotlib.pyplot as plt
|
|
30 |
from sklearn.manifold import TSNE
|
31 |
import seaborn as sns
|
32 |
|
|
|
|
|
|
|
33 |
class DimensionalityReduction:
|
34 |
def fit(self, X):
|
35 |
return self
|
@@ -39,7 +49,11 @@ class DimensionalityReduction:
|
|
39 |
|
40 |
class ClusteringWithTopic:
|
41 |
def __init__(self, df, n_topics=3):
|
42 |
-
|
|
|
|
|
|
|
|
|
43 |
# umap_model = DimensionalityReduction()
|
44 |
umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', init = 'pca')
|
45 |
hdbscan_model = AgglomerativeClustering(n_clusters=n_topics)
|
@@ -76,7 +90,11 @@ class ClusteringWithTopic:
|
|
76 |
初始化 ClusteringWithTopic,接受一个 n_topics_list,其中包含多个聚类数目,
|
77 |
选取 silhouette_score 最高的结果。
|
78 |
"""
|
79 |
-
|
|
|
|
|
|
|
|
|
80 |
self.embeddings = embedding_model.encode(df, show_progress_bar=True)
|
81 |
|
82 |
self.df = df
|
@@ -92,7 +110,7 @@ class ClusteringWithTopic:
|
|
92 |
# 用于存储不同聚类数目的结果
|
93 |
self.best_n_topics = None
|
94 |
self.best_labels = None
|
95 |
-
self.best_score = -1
|
96 |
# def fit_and_get_labels(self, X):
|
97 |
# topics, probs = self.topic_model.fit_transform(self.df, self.embeddings)
|
98 |
# return topics
|
|
|
1 |
from sklearn.metrics import silhouette_score
|
2 |
|
3 |
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
import matplotlib.pyplot as plt
|
6 |
import seaborn as sns
|
7 |
+
import json
|
8 |
from sklearn.manifold import TSNE
|
9 |
from sklearn.cluster import AgglomerativeClustering
|
10 |
+
import os
|
11 |
+
import tempfile
|
12 |
|
13 |
+
from sentence_transformers import SentenceTransformer
|
14 |
+
from bertopic import BERTopic
|
15 |
+
from bertopic.representation import KeyBERTInspired
|
16 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
17 |
+
from bertopic.vectorizers import ClassTfidfTransformer
|
18 |
+
from umap import UMAP
|
19 |
+
from ..path_utils import get_path, setup_hf_cache
|
20 |
|
21 |
plt.switch_backend('agg')
|
22 |
device = 0
|
|
|
37 |
from sklearn.manifold import TSNE
|
38 |
import seaborn as sns
|
39 |
|
40 |
+
# 设置 Hugging Face 缓存目录
|
41 |
+
cache_dir = setup_hf_cache()
|
42 |
+
|
43 |
class DimensionalityReduction:
|
44 |
def fit(self, X):
|
45 |
return self
|
|
|
49 |
|
50 |
class ClusteringWithTopic:
|
51 |
def __init__(self, df, n_topics=3):
|
52 |
+
try:
|
53 |
+
embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, cache_folder=cache_dir)
|
54 |
+
except Exception as e:
|
55 |
+
print(f"Error initializing SentenceTransformer: {e}")
|
56 |
+
embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
|
57 |
# umap_model = DimensionalityReduction()
|
58 |
umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', init = 'pca')
|
59 |
hdbscan_model = AgglomerativeClustering(n_clusters=n_topics)
|
|
|
90 |
初始化 ClusteringWithTopic,接受一个 n_topics_list,其中包含多个聚类数目,
|
91 |
选取 silhouette_score 最高的结果。
|
92 |
"""
|
93 |
+
try:
|
94 |
+
embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True, cache_folder=cache_dir)
|
95 |
+
except Exception as e:
|
96 |
+
print(f"Error initializing SentenceTransformer: {e}")
|
97 |
+
embedding_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
|
98 |
self.embeddings = embedding_model.encode(df, show_progress_bar=True)
|
99 |
|
100 |
self.df = df
|
|
|
110 |
# 用于存储不同聚类数目的结果
|
111 |
self.best_n_topics = None
|
112 |
self.best_labels = None
|
113 |
+
self.best_score = -1
|
114 |
# def fit_and_get_labels(self, X):
|
115 |
# topics, probs = self.topic_model.fit_transform(self.df, self.embeddings)
|
116 |
# return topics
|
src/demo/survey_generation_pipeline/main.py
CHANGED
@@ -27,6 +27,10 @@ import os
|
|
27 |
from markdown_pdf import MarkdownPdf, Section # Assuming you are using markdown_pdf
|
28 |
from typing import Any
|
29 |
import xml.etree.ElementTree as ET
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def clean_str(input_str):
|
32 |
input_str = str(input_str).strip().lower()
|
@@ -286,15 +290,19 @@ class ASG_system:
|
|
286 |
self.pipeline = None
|
287 |
|
288 |
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
|
|
|
|
|
|
|
|
298 |
# self.pipeline.model.load_adapter(peft_model_id = "technicolor/llama3.1_8b_outline_generation", adapter_name="outline")
|
299 |
# self.pipeline.model.load_adapter(peft_model_id ="technicolor/llama3.1_8b_abstract_generation", adapter_name="abstract")
|
300 |
# self.pipeline.model.load_adapter(peft_model_id ="technicolor/llama3.1_8b_conclusion_generation", adapter_name="conclusion")
|
|
|
27 |
from markdown_pdf import MarkdownPdf, Section # Assuming you are using markdown_pdf
|
28 |
from typing import Any
|
29 |
import xml.etree.ElementTree as ET
|
30 |
+
from .path_utils import get_path, setup_hf_cache
|
31 |
+
|
32 |
+
# 设置 Hugging Face 缓存目录
|
33 |
+
cache_dir = setup_hf_cache()
|
34 |
|
35 |
def clean_str(input_str):
|
36 |
input_str = str(input_str).strip().lower()
|
|
|
290 |
self.pipeline = None
|
291 |
|
292 |
|
293 |
+
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
294 |
+
try:
|
295 |
+
self.embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
296 |
+
except Exception as e:
|
297 |
+
print(f"Error initializing embedder: {e}")
|
298 |
+
self.embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
299 |
+
self.pipeline = transformers.pipeline(
|
300 |
+
"text-generation",
|
301 |
+
model=model_id,
|
302 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
303 |
+
token = os.getenv('HF_API_KEY'),
|
304 |
+
device_map="auto",
|
305 |
+
)
|
306 |
# self.pipeline.model.load_adapter(peft_model_id = "technicolor/llama3.1_8b_outline_generation", adapter_name="outline")
|
307 |
# self.pipeline.model.load_adapter(peft_model_id ="technicolor/llama3.1_8b_abstract_generation", adapter_name="abstract")
|
308 |
# self.pipeline.model.load_adapter(peft_model_id ="technicolor/llama3.1_8b_conclusion_generation", adapter_name="conclusion")
|
src/demo/survey_generator_api.py
CHANGED
@@ -9,6 +9,10 @@ import numpy as np
|
|
9 |
from numpy.linalg import norm
|
10 |
import openai
|
11 |
from .asg_retriever import Retriever
|
|
|
|
|
|
|
|
|
12 |
|
13 |
def getQwenClient():
|
14 |
# openai_api_key = os.environ.get("OPENAI_API_KEY")
|
@@ -506,7 +510,7 @@ Survey Paper Content for "{section_title}":
|
|
506 |
response = generateResponse(client, formatted_prompt).strip()
|
507 |
sentences = re.split(r'(?<=[.!?])\s+', response.strip())
|
508 |
|
509 |
-
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
510 |
sentence_embeddings = embedder.embed_documents(sentences)
|
511 |
chunk_texts = [c["content"] for c in citation_data_list]
|
512 |
chunk_sources = [c["source"] for c in citation_data_list]
|
@@ -627,7 +631,7 @@ Survey Paper Content for "{section_title}":
|
|
627 |
para_index_map.append(p_idx)
|
628 |
|
629 |
# -- 3. 对所有句子进行向量化嵌入(保持逻辑:一次性处理全文) ---
|
630 |
-
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
631 |
sentence_embeddings = embedder.embed_documents(all_sentences)
|
632 |
|
633 |
# -- 4. 对 citation_data_list 做向量化嵌入 ---
|
@@ -763,25 +767,17 @@ def query_embedding_for_title(
|
|
763 |
n_results: int = 1,
|
764 |
embedder: HuggingFaceEmbeddings = None
|
765 |
):
|
766 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
767 |
retriever = Retriever()
|
768 |
-
|
769 |
-
|
770 |
-
query_result
|
771 |
-
collection_name=collection_name,
|
772 |
-
query_embeddings=[title_embedding],
|
773 |
-
n_results=n_results
|
774 |
-
)
|
775 |
-
# old
|
776 |
-
# query_result_chunks = query_result["documents"][0]
|
777 |
-
# for chunk in query_result_chunks:
|
778 |
-
# final_context += chunk.strip() + "//\n"
|
779 |
-
|
780 |
-
# 2025
|
781 |
-
if "documents" in query_result and len(query_result["documents"]) > 0:
|
782 |
-
for chunk in query_result["documents"][0]:
|
783 |
-
final_context += chunk.strip() + "//\n"
|
784 |
-
return final_context
|
785 |
|
786 |
# old
|
787 |
def generate_context_list(outline, collection_list):
|
@@ -812,32 +808,32 @@ def generate_context_list(outline, collection_list):
|
|
812 |
|
813 |
# 2025
|
814 |
def generate_context_list(outline, collection_list):
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
context_list_final = []
|
821 |
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
839 |
-
|
840 |
-
|
|
|
841 |
|
842 |
# 1.8 输入introduction 输出带引用 (collection name) 的introduction
|
843 |
def introduction_with_citations(
|
@@ -847,110 +843,65 @@ def introduction_with_citations(
|
|
847 |
dynamic_threshold: bool = True,
|
848 |
diversity_limit: int = 3
|
849 |
) -> str:
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
:
|
855 |
-
|
856 |
-
|
857 |
-
|
858 |
-
|
859 |
-
|
860 |
-
#
|
861 |
-
|
862 |
-
if not paragraphs:
|
863 |
-
return intro_text
|
864 |
-
|
865 |
-
# 2. 逐段落拆分句子,记录每句所属段落编号
|
866 |
-
all_sentences = []
|
867 |
-
para_index_map = []
|
868 |
-
for p_idx, para in enumerate(paragraphs):
|
869 |
-
if not para.strip():
|
870 |
-
# 空段落,直接跳过切句,保持段落分隔
|
871 |
-
continue
|
872 |
-
# 用正则在段落内部按 .!? 分句
|
873 |
-
sentences_in_para = re.split(r'(?<=[.!?])\s+', para)
|
874 |
-
for sent in sentences_in_para:
|
875 |
-
if sent:
|
876 |
-
all_sentences.append(sent)
|
877 |
-
para_index_map.append(p_idx)
|
878 |
-
|
879 |
-
# 如果拆不出任何句子,直接返回
|
880 |
-
if not all_sentences:
|
881 |
-
return intro_text
|
882 |
-
|
883 |
-
# 3. 对所有句子进行 Embedding
|
884 |
-
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
885 |
-
sentence_embeddings = embedder.embed_documents(all_sentences)
|
886 |
-
|
887 |
-
# 4. 对 citation_data_list 里每段文献块进行向量化
|
888 |
chunk_texts = [c["content"] for c in citation_data_list]
|
889 |
chunk_sources = [c["source"] for c in citation_data_list]
|
890 |
chunk_embeddings = embedder.embed_documents(chunk_texts)
|
891 |
-
|
|
|
892 |
def cosine_sim(a, b):
|
893 |
return np.dot(a, b) / (norm(a) * norm(b) + 1e-9)
|
894 |
-
|
895 |
-
#
|
896 |
sim_matrix = []
|
897 |
for s_emb in sentence_embeddings:
|
898 |
row = [cosine_sim(s_emb, c_emb) for c_emb in chunk_embeddings]
|
899 |
sim_matrix.append(row)
|
900 |
sim_matrix = np.array(sim_matrix)
|
901 |
-
|
902 |
-
#
|
903 |
all_sims = sim_matrix.flatten()
|
904 |
mean_sim = np.mean(all_sims)
|
905 |
-
std_sim
|
906 |
k = 0.5
|
907 |
threshold = max(base_threshold, mean_sim + k * std_sim) if dynamic_threshold else base_threshold
|
908 |
-
|
909 |
-
#
|
910 |
candidates = []
|
911 |
-
for i in
|
912 |
-
for j in
|
913 |
-
if
|
914 |
-
candidates.append((i, j,
|
915 |
-
|
916 |
-
#
|
|
|
917 |
candidates.sort(key=lambda x: x[2], reverse=True)
|
918 |
-
|
919 |
-
# 记录:句子 -> 已分配的 source;并限制每个 source 最多引用次数
|
920 |
-
source_count = {src: 0 for src in chunk_sources}
|
921 |
assigned = {}
|
922 |
-
|
923 |
-
for (sent_id, chk_id,
|
924 |
if sent_id not in assigned:
|
925 |
src = chunk_sources[chk_id]
|
926 |
if source_count[src] < diversity_limit:
|
927 |
assigned[sent_id] = src
|
928 |
source_count[src] += 1
|
929 |
-
|
930 |
-
#
|
931 |
updated_sentences = []
|
932 |
-
for i, sentence in enumerate(
|
933 |
if i in assigned:
|
934 |
updated_sentences.append(sentence + f" [{assigned[i]}]")
|
935 |
else:
|
936 |
updated_sentences.append(sentence)
|
937 |
-
|
938 |
-
|
939 |
-
updated_paras = [""] * len(paragraphs)
|
940 |
-
para_sentences_map = [[] for _ in range(len(paragraphs))]
|
941 |
-
|
942 |
-
for s_idx, sent in enumerate(updated_sentences):
|
943 |
-
p_idx = para_index_map[s_idx]
|
944 |
-
para_sentences_map[p_idx].append(sent)
|
945 |
-
|
946 |
-
for i in range(len(paragraphs)):
|
947 |
-
if not paragraphs[i].strip():
|
948 |
-
# 保持空段落不动
|
949 |
-
updated_paras[i] = paragraphs[i]
|
950 |
-
else:
|
951 |
-
# 同段落内的句子用空格拼起来
|
952 |
-
updated_paras[i] = " ".join(para_sentences_map[i])
|
953 |
-
|
954 |
-
# 11. 用原先换行分隔符拼回
|
955 |
-
updated_intro = "\n\n".join(updated_paras)
|
956 |
-
return updated_intro
|
|
|
9 |
from numpy.linalg import norm
|
10 |
import openai
|
11 |
from .asg_retriever import Retriever
|
12 |
+
from .path_utils import get_path, setup_hf_cache
|
13 |
+
|
14 |
+
# 设置 Hugging Face 缓存目录
|
15 |
+
cache_dir = setup_hf_cache()
|
16 |
|
17 |
def getQwenClient():
|
18 |
# openai_api_key = os.environ.get("OPENAI_API_KEY")
|
|
|
510 |
response = generateResponse(client, formatted_prompt).strip()
|
511 |
sentences = re.split(r'(?<=[.!?])\s+', response.strip())
|
512 |
|
513 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
514 |
sentence_embeddings = embedder.embed_documents(sentences)
|
515 |
chunk_texts = [c["content"] for c in citation_data_list]
|
516 |
chunk_sources = [c["source"] for c in citation_data_list]
|
|
|
631 |
para_index_map.append(p_idx)
|
632 |
|
633 |
# -- 3. 对所有句子进行向量化嵌入(保持逻辑:一次性处理全文) ---
|
634 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
635 |
sentence_embeddings = embedder.embed_documents(all_sentences)
|
636 |
|
637 |
# -- 4. 对 citation_data_list 做向量化嵌入 ---
|
|
|
767 |
n_results: int = 1,
|
768 |
embedder: HuggingFaceEmbeddings = None
|
769 |
):
|
770 |
+
if embedder is None:
|
771 |
+
try:
|
772 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
773 |
+
except Exception as e:
|
774 |
+
print(f"Error initializing embedder: {e}")
|
775 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
776 |
+
|
777 |
retriever = Retriever()
|
778 |
+
query_embeddings = embedder.embed_query(title)
|
779 |
+
query_result = retriever.query_chroma(collection_name=collection_name, query_embeddings=[query_embeddings], n_results=n_results)
|
780 |
+
return query_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
781 |
|
782 |
# old
|
783 |
def generate_context_list(outline, collection_list):
|
|
|
808 |
|
809 |
# 2025
|
810 |
def generate_context_list(outline, collection_list):
|
811 |
+
try:
|
812 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
813 |
+
except Exception as e:
|
814 |
+
print(f"Error initializing embedder: {e}")
|
815 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
|
|
816 |
|
817 |
+
retriever = Retriever()
|
818 |
+
context_list = []
|
819 |
+
|
820 |
+
for section_title in outline:
|
821 |
+
query_embeddings = embedder.embed_query(section_title)
|
822 |
+
final_context = ""
|
823 |
+
seen_chunks = set()
|
824 |
|
825 |
+
for collection_name in collection_list:
|
826 |
+
query_result = retriever.query_chroma(collection_name=collection_name, query_embeddings=[query_embeddings], n_results=2)
|
827 |
+
query_result_chunks = query_result["documents"][0]
|
828 |
+
|
829 |
+
for chunk in query_result_chunks:
|
830 |
+
if chunk not in seen_chunks:
|
831 |
+
final_context += chunk.strip() + "//\n"
|
832 |
+
seen_chunks.add(chunk)
|
833 |
+
|
834 |
+
context_list.append(final_context)
|
835 |
+
|
836 |
+
return context_list
|
837 |
|
838 |
# 1.8 输入introduction 输出带引用 (collection name) 的introduction
|
839 |
def introduction_with_citations(
|
|
|
843 |
dynamic_threshold: bool = True,
|
844 |
diversity_limit: int = 3
|
845 |
) -> str:
|
846 |
+
# 将介绍文本按句子分割
|
847 |
+
sentences = re.split(r'(?<=[.!?])\s+', intro_text.strip())
|
848 |
+
|
849 |
+
# 初始化 embedder
|
850 |
+
try:
|
851 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=cache_dir)
|
852 |
+
except Exception as e:
|
853 |
+
print(f"Error initializing embedder: {e}")
|
854 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
855 |
+
|
856 |
+
# 对句子和引用数据进行向量化
|
857 |
+
sentence_embeddings = embedder.embed_documents(sentences)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
858 |
chunk_texts = [c["content"] for c in citation_data_list]
|
859 |
chunk_sources = [c["source"] for c in citation_data_list]
|
860 |
chunk_embeddings = embedder.embed_documents(chunk_texts)
|
861 |
+
|
862 |
+
# 计算余弦相似度
|
863 |
def cosine_sim(a, b):
|
864 |
return np.dot(a, b) / (norm(a) * norm(b) + 1e-9)
|
865 |
+
|
866 |
+
# 构建相似度矩阵
|
867 |
sim_matrix = []
|
868 |
for s_emb in sentence_embeddings:
|
869 |
row = [cosine_sim(s_emb, c_emb) for c_emb in chunk_embeddings]
|
870 |
sim_matrix.append(row)
|
871 |
sim_matrix = np.array(sim_matrix)
|
872 |
+
|
873 |
+
# 计算动态阈值
|
874 |
all_sims = sim_matrix.flatten()
|
875 |
mean_sim = np.mean(all_sims)
|
876 |
+
std_sim = np.std(all_sims)
|
877 |
k = 0.5
|
878 |
threshold = max(base_threshold, mean_sim + k * std_sim) if dynamic_threshold else base_threshold
|
879 |
+
|
880 |
+
# 找出候选引用
|
881 |
candidates = []
|
882 |
+
for i, sent in enumerate(sentences):
|
883 |
+
for j, sim in enumerate(sim_matrix[i]):
|
884 |
+
if sim >= threshold:
|
885 |
+
candidates.append((i, j, sim))
|
886 |
+
|
887 |
+
# 按相似度排序并分配引用
|
888 |
+
source_count = {s: 0 for s in chunk_sources}
|
889 |
candidates.sort(key=lambda x: x[2], reverse=True)
|
|
|
|
|
|
|
890 |
assigned = {}
|
891 |
+
|
892 |
+
for (sent_id, chk_id, sim) in candidates:
|
893 |
if sent_id not in assigned:
|
894 |
src = chunk_sources[chk_id]
|
895 |
if source_count[src] < diversity_limit:
|
896 |
assigned[sent_id] = src
|
897 |
source_count[src] += 1
|
898 |
+
|
899 |
+
# 更新句子
|
900 |
updated_sentences = []
|
901 |
+
for i, sentence in enumerate(sentences):
|
902 |
if i in assigned:
|
903 |
updated_sentences.append(sentence + f" [{assigned[i]}]")
|
904 |
else:
|
905 |
updated_sentences.append(sentence)
|
906 |
+
|
907 |
+
return " ".join(updated_sentences)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/demo/views.py
CHANGED
@@ -44,7 +44,7 @@ from dotenv import load_dotenv
|
|
44 |
from pathlib import Path
|
45 |
from markdown_pdf import MarkdownPdf, Section
|
46 |
import tempfile
|
47 |
-
from .path_utils import get_path
|
48 |
|
49 |
dotenv_path = os.path.join(os.path.dirname(__file__), ".env")
|
50 |
load_dotenv()
|
@@ -59,6 +59,9 @@ load_dotenv()
|
|
59 |
# print(f"OPENAI_API_KEY: {openai_api_key}")
|
60 |
# print(f"OPENAI_API_BASE: {openai_api_base}")
|
61 |
|
|
|
|
|
|
|
62 |
# 获取路径配置
|
63 |
paths_config = get_path('pdf') # 使用 get_path 函数获取路径配置
|
64 |
DATA_PATH = get_path('pdf')
|
@@ -144,8 +147,15 @@ Global_cluster_names = []
|
|
144 |
Global_citation_data = []
|
145 |
Global_cluster_num = 4
|
146 |
|
147 |
-
|
148 |
-
embedder = HuggingFaceEmbeddings(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
from demo.category_and_tsne import clustering
|
151 |
|
|
|
44 |
from pathlib import Path
|
45 |
from markdown_pdf import MarkdownPdf, Section
|
46 |
import tempfile
|
47 |
+
from .path_utils import get_path, setup_hf_cache
|
48 |
|
49 |
dotenv_path = os.path.join(os.path.dirname(__file__), ".env")
|
50 |
load_dotenv()
|
|
|
59 |
# print(f"OPENAI_API_KEY: {openai_api_key}")
|
60 |
# print(f"OPENAI_API_BASE: {openai_api_base}")
|
61 |
|
62 |
+
# 设置 Hugging Face 缓存目录
|
63 |
+
cache_dir = setup_hf_cache()
|
64 |
+
|
65 |
# 获取路径配置
|
66 |
paths_config = get_path('pdf') # 使用 get_path 函数获取路径配置
|
67 |
DATA_PATH = get_path('pdf')
|
|
|
147 |
Global_citation_data = []
|
148 |
Global_cluster_num = 4
|
149 |
|
150 |
+
try:
|
151 |
+
embedder = HuggingFaceEmbeddings(
|
152 |
+
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
153 |
+
cache_folder=cache_dir
|
154 |
+
)
|
155 |
+
except Exception as e:
|
156 |
+
print(f"Error initializing embedder: {e}")
|
157 |
+
# 如果初始化失败,尝试使用默认设置
|
158 |
+
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
159 |
|
160 |
from demo.category_and_tsne import clustering
|
161 |
|