Upload 16 files
- README.md +5 -6
- ai.py +199 -0
- api.py +117 -0
- app.py +26 -0
- config.json +14 -0
- config.py +36 -0
- console.py +127 -0
- contents.py +81 -0
- index/bkp/dd771cb6c4718ace4c4c596f4792cfdd.bin +3 -0
- index/bkp/dd771cb6c4718ace4c4c596f4792cfdd.csv +0 -0
- index/dd771cb6c4718ace4c4c596f4792cfdd.bin +3 -0
- index/dd771cb6c4718ace4c4c596f4792cfdd.csv +0 -0
- main.py +26 -0
- requirements.txt +79 -0
- storage.py +171 -0
- webui.py +71 -0
README.md
CHANGED
@@ -1,12 +1,11 @@
 ---
-
-
-
-
+python_version: 3.10.6
+title: Kanunasor
+emoji: 🏆
+colorFrom: red
+colorTo: blue
 sdk: gradio
 sdk_version: 3.32.0
 app_file: app.py
 pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
ai.py
ADDED
@@ -0,0 +1,199 @@
import numpy as np
import openai
import tiktoken
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from config import Config


class AI:
    """The AI class."""

    def __init__(self, cfg: Config):
        openai.api_key = cfg.open_ai_key
        openai.proxy = cfg.open_ai_proxy
        self._chat_model = cfg.open_ai_chat_model
        self._use_stream = cfg.use_stream
        self._encoding = tiktoken.encoding_for_model('gpt-3.5-turbo')
        self._language = cfg.language
        self._temperature = cfg.temperature

    def _chat_stream(self, messages: list[dict], use_stream: bool = None) -> str:
        use_stream = use_stream if use_stream is not None else self._use_stream
        response = openai.ChatCompletion.create(
            temperature=self._temperature,
            stream=use_stream,
            model=self._chat_model,
            messages=messages,
        )
        if use_stream:
            data = ""
            for chunk in response:
                if chunk.choices[0].delta.get('content', None) is not None:
                    data += chunk.choices[0].delta.content
                    print(chunk.choices[0].delta.content, end='')
            print()
            return data.strip()
        else:
            print(response.choices[0].message.content.strip())
            print(f"Total tokens used: {response.usage.total_tokens}, "
                  f"cost: ${response.usage.total_tokens / 1000 * 0.002}")
            return response.choices[0].message.content.strip()

    def _num_tokens_from_string(self, string: str) -> int:
        """Returns the number of tokens in a text string."""
        num_tokens = len(self._encoding.encode(string))
        return num_tokens

    def completion(self, query: str, context: list[str]):
        """Create a completion."""
        context = self._cut_texts(context)
        print(f"Number of query fragments: {len(context)}")

        text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
        result = self._chat_stream([
            {'role': 'system',
             'content': f'You are a helpful AI article assistant. '
                        f'The following are the relevant article content fragments found from the article. '
                        f'The relevance is sorted from high to low. '
                        f'You can only answer according to the following content:\n```\n{text}\n```\n'
                        f'You need to carefully consider your answer to ensure that it is based on the context. '
                        f'If the context does not mention the content or it is uncertain whether it is correct, '
                        f'please answer "Bu bilgiye tam olarak hakim değilim, lütfen uzmanlarımıza danışın. Başka bir soru sorabilirsiniz."'
                        f'You must use {self._language} to respond.'},
            {'role': 'user', 'content': query},
        ])
        return result

    def _cut_texts(self, context):
        maximum = 4096 - 1024
        for index, text in enumerate(context):
            maximum -= self._num_tokens_from_string(text)
            if maximum < 0:
                context = context[:index + 1]
                print(f"Exceeded maximum length, cut the first {index + 1} fragments")
                break
        return context

    def get_keywords(self, query: str) -> str:
        """Get keywords from the query."""
        result = self._chat_stream([
            {'role': 'user',
             'content': f'You need to extract keywords from the statement or question and '
                        f'return a series of keywords separated by commas.\ncontent: {query}\nkeywords: '},
        ], use_stream=False)
        return result

    @staticmethod
    def create_embedding(text: str) -> (str, list[float]):
        """Create an embedding for the provided text."""
        embedding = openai.Embedding.create(model="text-embedding-ada-002", input=text)
        return text, embedding.data[0].embedding

    def create_embeddings(self, texts: list[str]) -> (list[tuple[str, list[float]]], int):
        """Create embeddings for the provided input."""
        result = []
        query_len = 0
        start_index = 0
        tokens = 0

        def get_embedding(input_slice: list[str]):
            embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
            return [(txt, data.embedding) for txt, data in
                    zip(input_slice, embedding.data)], embedding.usage.total_tokens

        for index, text in enumerate(texts):
            query_len += self._num_tokens_from_string(text)
            if query_len > 8192 - 1024:
                ebd, tk = get_embedding(texts[start_index:index + 1])
                print(f"Query fragments used tokens: {tk}, cost: ${tk / 1000 * 0.0004}")
                query_len = 0
                start_index = index + 1
                tokens += tk
                result.extend(ebd)

        if query_len > 0:
            ebd, tk = get_embedding(texts[start_index:])
            print(f"Query fragments used tokens: {tk}, cost: ${tk / 1000 * 0.0004}")
            tokens += tk
            result.extend(ebd)
        return result, tokens

    def generate_summary(self, embeddings, num_candidates=3, use_sif=False):
        """Generate a summary for the provided embeddings."""
        avg_func = self._calc_paragraph_avg_embedding_with_sif if use_sif else self._calc_avg_embedding
        avg_embedding = np.array(avg_func(embeddings))

        paragraphs = [e[0] for e in embeddings]
        embeddings = np.array([e[1] for e in embeddings])
        # Calculate the similarity score between each paragraph and the entire text.
        similarity_scores = cosine_similarity(embeddings, avg_embedding.reshape(1, -1)).flatten()

        # Select the paragraphs with the highest similarity scores as the candidate paragraphs for the summary.
        candidate_indices = np.argsort(similarity_scores)[::-1][:num_candidates]
        candidate_paragraphs = [f"paragraph {i}: {paragraphs[i]}" for i in candidate_indices]

        print("Calculation completed, start generating summary")

        candidate_paragraphs = self._cut_texts(candidate_paragraphs)

        text = "\n".join(f"{index}. {text}" for index, text in enumerate(candidate_paragraphs))
        result = self._chat_stream([
            {'role': 'system',
             'content': f'As a helpful AI article assistant, '
                        f'I have retrieved the following relevant text fragments from the article, '
                        f'sorted by relevance from high to low. '
                        f'You need to summarize the entire article from these fragments, '
                        f'and present the final result in {self._language}:\n\n{text}\n\n{self._language} summary:'},
        ])
        return result

    @staticmethod
    def _calc_avg_embedding(embeddings) -> list[float]:
        # Calculate the average embedding for the entire text.
        avg_embedding = np.zeros(len(embeddings[0][1]))
        for emb in embeddings:
            avg_embedding += np.array(emb[1])
        avg_embedding /= len(embeddings)
        return avg_embedding.tolist()

    @staticmethod
    def _calc_paragraph_avg_embedding_with_sif(paragraph_list) -> list[float]:
        # calculate the SIF embedding for the entire text
        alpha = 0.001
        # calculate the total number of sentences
        n_sentences = len(paragraph_list)

        # calculate the total number of dimensions in the embeddings
        n_dims = len(paragraph_list[0][1])

        # calculate the IDF values for each word in the sentences
        vectorizer = TfidfVectorizer(use_idf=True)
        vectorizer.fit_transform([paragraph for paragraph, _ in paragraph_list])
        idf = vectorizer.idf_

        # calculate the SIF weights for each sentence
        weights = np.zeros((n_sentences, n_dims))
        for i, (sentence, embedding) in enumerate(paragraph_list):
            sentence_words = sentence.split()
            for word in sentence_words:
                try:
                    word_index = vectorizer.vocabulary_[word]
                    word_idf = idf[word_index]
                    word_weight = alpha / (alpha + word_idf)
                    weights[i] += word_weight * (np.array(embedding) / np.max(embedding))
                except KeyError:
                    pass

        # calculate the weighted average of the sentence embeddings
        weights_sum = np.sum(weights, axis=0)
        weights_sum /= n_sentences
        avg_embedding = np.zeros(n_dims)
        for i, (sentence, embedding) in enumerate(paragraph_list):
            avg_embedding += (np.array(embedding) / np.max(embedding)) - weights[i]
        avg_embedding /= n_sentences

        return avg_embedding.tolist()
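For reference, a minimal sketch of how the AI class above is typically driven. This is an assumption-laden smoke test, not part of the upload: it presumes a valid OpenAI key exported as the open_ai_key environment variable, network access, and the bundled config.json; the paragraph strings and question are made up.

# Hypothetical smoke test for ai.AI; needs a real OpenAI key and network access.
from ai import AI
from config import Config

cfg = Config()                      # reads config.json plus the open_ai_key env var
ai = AI(cfg)

paragraphs = ["Example paragraph one.", "Example paragraph two."]   # made-up input
embeddings, tokens = ai.create_embeddings(paragraphs)               # (text, vector) pairs and token count
print(len(embeddings), "embeddings,", tokens, "tokens")

question = "What does the example text talk about?"                 # made-up question
keywords = ai.get_keywords(question)                                # keyword extraction via chat completion
_, query_vector = ai.create_embedding(keywords)                     # embedding for retrieval
answer = ai.completion(question, paragraphs)                        # grounded answer from the fragments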
api.py
ADDED
@@ -0,0 +1,117 @@
import os
import shutil

import uvicorn
import xxhash
from fastapi import FastAPI, UploadFile, File
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse

from ai import AI
from config import Config
from contents import web_crawler_newspaper, extract_text_from_txt, extract_text_from_docx, \
    extract_text_from_pdf
from storage import Storage


def api(cfg: Config):
    """Run the API."""

    cfg.use_stream = False
    ai = AI(cfg)

    app = FastAPI()

    class CrawlerUrlRequest(BaseModel):
        url: str

    @app.post("/crawler_url")
    async def crawler_url(req: CrawlerUrlRequest):
        """Crawl the URL."""
        contents, lang = web_crawler_newspaper(req.url)
        hash_id = xxhash.xxh3_128_hexdigest('\n'.join(contents))
        tokens = _save_to_storage(contents, hash_id)
        return {"code": 0, "msg": "ok", "data": {"uri": f"{hash_id}/{lang}", "tokens": tokens}}

    def _save_to_storage(contents, hash_id):
        storage = Storage.create_storage(cfg)
        if storage.been_indexed(hash_id):
            return 0
        else:
            embeddings, tokens = ai.create_embeddings(contents)
            storage.add_all(embeddings, hash_id)
            return tokens

    @app.post("/upload_file")
    async def create_upload_file(file: UploadFile = File(...)):
        """Upload file."""
        # save file to disk
        file_name = file.filename
        os.makedirs('./upload', exist_ok=True)
        upload_path = os.path.join('./upload', file_name)
        with open(upload_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        if file_name.endswith('.pdf'):
            contents, lang = extract_text_from_pdf(upload_path)
        elif file_name.endswith('.txt'):
            contents, lang = extract_text_from_txt(upload_path)
        elif file_name.endswith('.docx'):
            contents, lang = extract_text_from_docx(upload_path)
        else:
            return {"code": 1, "msg": "not support", "data": {}}
        hash_id = xxhash.xxh3_128_hexdigest('\n'.join(contents))
        tokens = _save_to_storage(contents, hash_id)
        os.remove(upload_path)
        return {"code": 0, "msg": "ok", "data": {"uri": f"{hash_id}/{lang}", "tokens": tokens}}

    @app.get("/summary")
    async def summary(uri: str):
        """Generate summary."""
        hash_id, lang = uri.split('/')
        storage = Storage.create_storage(cfg)
        if not storage or not lang:
            return {"code": 1, "msg": "not found", "data": {}}
        s = ai.generate_summary(storage.get_all_embeddings(hash_id), num_candidates=100,
                                use_sif=lang not in ['zh', 'ja', 'ko', 'hi', 'ar', 'fa'])
        return {"code": 0, "msg": "ok", "data": {"summary": s}}

    class AnswerRequest(BaseModel):
        uri: str
        query: str

    @app.get("/answer")
    async def answer(req: AnswerRequest):
        """Query."""
        hash_id, lang = req.uri.split('/')
        storage = Storage.create_storage(cfg)
        if not storage or not lang:
            return {"code": 1, "msg": "not found", "data": {}}
        keywords = ai.get_keywords(req.query)
        _, embedding = ai.create_embedding(keywords)
        texts = storage.get_texts(embedding, hash_id)
        s = ai.completion(req.query, texts)
        return {"code": 0, "msg": "ok", "data": {"answer": s}}

    @app.exception_handler(RequestValidationError)
    async def validate_error_handler(request: Request, exc: RequestValidationError):
        """Error handler."""
        print("validate_error_handler: ", request.url, exc)
        return JSONResponse(
            status_code=400,
            content={"code": 1, "msg": str(exc.errors()), "data": {}},
        )

    @app.exception_handler(HTTPException)
    async def http_error_handler(request: Request, exc):
        """Error handler."""
        print("http error_handler: ", request.url, exc)
        return JSONResponse(
            status_code=400,
            content={"code": 1, "msg": exc.detail, "data": {}},
        )

    # run the API
    uvicorn.run(app, host=cfg.api_host, port=cfg.api_port)
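A rough client-side sketch of exercising these endpoints once the service runs in api mode. This is an assumption: the base URL comes from the api_host/api_port values in config.json, and the article URL and query are placeholders.

# Hypothetical client for the FastAPI endpoints above, using the requests package.
import requests

base = "http://localhost:9531"                      # api_host / api_port from config.json

r = requests.post(f"{base}/crawler_url", json={"url": "https://example.com/article"})  # placeholder URL
uri = r.json()["data"]["uri"]                       # "<hash_id>/<lang>"

print(requests.get(f"{base}/summary", params={"uri": uri}).json())

# /answer is declared as GET with a JSON request body, so the body is sent explicitly.
print(requests.get(f"{base}/answer", json={"uri": uri, "query": "What is this article about?"}).json())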
app.py
ADDED
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from api import api
from config import Config
from console import console
from webui import webui


def run():
    """Run the program."""
    cfg = Config()

    mode = cfg.mode
    if mode == 'console':
        console(cfg)
    elif mode == 'api':
        api(cfg)
    elif mode == 'webui':
        webui(cfg)
    else:
        raise ValueError('mode must be console, api or webui')


if __name__ == '__main__':
    run()
config.json
ADDED
@@ -0,0 +1,14 @@
{
  "temperature": 0.1,
  "language": "Turkish",
  "open_ai_chat_model": "gpt-3.5-turbo",
  "use_stream": false,
  "use_postgres": false,
  "index_path": "./index",
  "postgres_url": "postgresql://localhost:5432/mydb",
  "mode": "webui",
  "api_port": 9531,
  "api_host": "localhost",
  "webui_port": 7860,
  "webui_host": "0.0.0.0"
}
config.py
ADDED
@@ -0,0 +1,36 @@
import json
import os


class Config:
    def __init__(self):
        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
        if not os.path.exists(config_path):
            raise FileNotFoundError(f'config.json not found at {config_path}, '
                                    f'please copy config.example.json to config.json and modify it.')
        with open(config_path, 'r') as f:
            # The OpenAI key is read from the environment, not from config.json;
            # .get() keeps the dedicated ValueError below reachable when it is missing.
            self.open_ai_key = os.environ.get('open_ai_key')
            self.config = json.load(f)
        self.language = self.config.get('language', 'Chinese')
        self.open_ai_proxy = self.config.get('open_ai_proxy')
        self.open_ai_chat_model = self.config.get('open_ai_chat_model', 'gpt-3.5-turbo')
        if not self.open_ai_key:
            raise ValueError('open_ai_key is not set')
        self.temperature = self.config.get('temperature', 0.1)
        if self.temperature < 0 or self.temperature > 1:
            raise ValueError('temperature must be between 0 and 1, less is more conservative, more is more creative')
        self.use_stream = self.config.get('use_stream', False)
        self.use_postgres = self.config.get('use_postgres', False)
        if not self.use_postgres:
            self.index_path = self.config.get('index_path', './temp')
            os.makedirs(self.index_path, exist_ok=True)
        self.postgres_url = self.config.get('postgres_url')
        if self.use_postgres and self.postgres_url is None:
            raise ValueError('postgres_url is not set')
        self.mode = self.config.get('mode', 'webui')
        if self.mode not in ['console', 'api', 'webui']:
            raise ValueError('mode must be console, api or webui')
        self.api_port = self.config.get('api_port', 9531)
        self.api_host = self.config.get('api_host', 'localhost')
        self.webui_port = self.config.get('webui_port', 7860)
        self.webui_host = self.config.get('webui_host', '0.0.0.0')
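Note that config.json deliberately carries no API key; Config reads it from the open_ai_key environment variable (on a Space this would come from the repository secrets). A minimal, assumption-laden launch sketch; the key value is a placeholder:

# Hypothetical launcher: the key must be in the environment before Config() is constructed.
import os

os.environ["open_ai_key"] = "sk-..."   # placeholder; set a real key or a Space secret instead

from config import Config

cfg = Config()
print(cfg.mode, cfg.webui_host, cfg.webui_port)   # e.g. "webui 0.0.0.0 7860" with the bundled config.json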
console.py
ADDED
@@ -0,0 +1,127 @@
import xxhash

from ai import AI
from config import Config
from storage import Storage
from contents import *


def console(cfg: Config):
    try:
        while True:
            if not _console(cfg):
                return
    except KeyboardInterrupt:
        print("exit")


def _console(cfg: Config) -> bool:
    """Run the console."""

    contents, lang, identify = _get_contents()

    print("The article has been retrieved, and the number of text fragments is:", len(contents))
    for content in contents:
        print('\t', content)

    ai = AI(cfg)
    storage = Storage.create_storage(cfg)

    print("=====================================")
    if storage.been_indexed(identify):
        print("The article has already been indexed, so there is no need to index it again.")
        print("=====================================")
    else:
        # 1. Generate an embedding for each paragraph of the article.
        embeddings, tokens = ai.create_embeddings(contents)
        print(f"Embeddings have been created with {len(embeddings)} embeddings, using {tokens} tokens, "
              f"costing ${tokens / 1000 * 0.0004}")

        storage.add_all(embeddings, identify)
        print("The embeddings have been saved.")
        print("=====================================")

    while True:
        query = input("Please enter your query (/help to view commands):").strip()
        if query.startswith("/"):
            if query == "/quit":
                return False
            elif query == "/reset":
                print("=====================================")
                return True
            elif query == "/summary":
                # Generate an embedding-based summary, using a weighted average based on SIF
                # or a direct average depending on the language.
                ai.generate_summary(storage.get_all_embeddings(identify), num_candidates=100,
                                    use_sif=lang not in ['zh', 'ja', 'ko', 'hi', 'ar', 'fa'])
            elif query == "/reindex":
                # Re-index, which will clear the database.
                storage.clear(identify)
                embeddings, tokens = ai.create_embeddings(contents)
                print(f"Embeddings have been created with {len(embeddings)} embeddings, using {tokens} tokens, "
                      f"costing ${tokens / 1000 * 0.0004}")

                storage.add_all(embeddings, identify)
                print("The embeddings have been saved.")
            elif query == "/help":
                print("Enter /summary to generate an embedding-based summary.")
                print("Enter /reindex to re-index the article.")
                print("Enter /reset to start over.")
                print("Enter /quit to exit.")
                print("Enter any other content for a query.")
            else:
                print("Invalid command.")
                print("Enter /summary to generate an embedding-based summary.")
                print("Enter /reindex to re-index the article.")
                print("Enter /reset to start over.")
                print("Enter /quit to exit.")
                print("Enter any other content for a query.")
            print("=====================================")
            continue
        else:
            # 1. Generate keywords.
            print("Generate keywords.")
            keywords = ai.get_keywords(query)
            # 2. Generate an embedding for the question.
            _, embedding = ai.create_embedding(keywords)
            # 3. Find the most similar fragments from the database.
            texts = storage.get_texts(embedding, identify)
            print("Related fragments found (first 5):")
            for text in texts[:5]:
                print('\t', text)
            # 4. Push the relevant fragments to the AI, which will answer the question based on these fragments.
            ai.completion(query, texts)
            print("=====================================")


def _get_contents() -> tuple[list[str], str, str]:
    """Get the contents."""

    while True:
        try:
            url = input("Please enter the link to the article or the file path of the PDF/TXT/DOCX document: ").strip()
            if os.path.exists(url):
                if url.endswith('.pdf'):
                    contents, data = extract_text_from_pdf(url)
                elif url.endswith('.txt'):
                    contents, data = extract_text_from_txt(url)
                elif url.endswith('.docx'):
                    contents, data = extract_text_from_docx(url)
                else:
                    print("Unsupported file format.")
                    continue
            else:
                contents, data = web_crawler_newspaper(url)
            if not contents:
                print("Unable to retrieve the content of the article. Please enter the link to the article or "
                      "the file path of the PDF/TXT/DOCX document again.")
                continue
            return contents, data, xxhash.xxh3_128_hexdigest('\n'.join(contents))
        except Exception as e:
            print("Error:", e)
contents.py
ADDED
@@ -0,0 +1,81 @@
import os
import time

import PyPDF2
import docx
import readability
from langdetect import detect
from newspaper import fulltext, Article
from selenium import webdriver


def web_crawler_newspaper(url: str) -> tuple[list[str], str]:
    """Run the web crawler."""
    raw_html, lang = _get_raw_html(url)
    try:
        text = fulltext(raw_html, language=lang)
    except:
        article = Article(url)
        article.download()
        article.parse()
        text = article.text
    contents = [text.strip() for text in text.splitlines() if text.strip()]
    return contents, lang


def _get_raw_html(url):
    chrome_options = webdriver.ChromeOptions()
    chrome_options.add_argument('--headless')
    chrome_options.add_argument('--disable-gpu')
    chrome_options.add_argument('--no-sandbox')
    chrome_options.add_argument('--disable-dev-shm-usage')
    chrome_options.add_argument('--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
                                'AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36')

    with webdriver.Chrome(options=chrome_options) as driver:
        driver.get(url)
        print("Please wait for 5 seconds until the webpage finishes loading.")
        time.sleep(5)
        html = driver.page_source

    doc = readability.Document(html)
    html = doc.summary()
    lang = detect(html)
    return html, lang[0:2]


def extract_text_from_pdf(file_path: str) -> tuple[list[str], str]:
    """Extract text content from a PDF file."""
    with open(file_path, 'rb') as f:
        pdf_reader = PyPDF2.PdfReader(f)
        contents = []
        for page in pdf_reader.pages:
            page_text = page.extract_text().strip()
            raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
            new_text = ''
            for text in raw_text:
                new_text += text
                if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
                                '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
                    contents.append(new_text)
                    new_text = ''
            if new_text:
                contents.append(new_text)
        lang = detect('\n'.join(contents))
        return contents, lang[0:2]


def extract_text_from_txt(file_path: str) -> tuple[list[str], str]:
    """Extract text content from a TXT file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        contents = [text.strip() for text in f.readlines() if text.strip()]
        lang = detect('\n'.join(contents))
        return contents, lang[0:2]


def extract_text_from_docx(file_path: str) -> tuple[list[str], str]:
    """Extract text content from a DOCX file."""
    document = docx.Document(file_path)
    contents = [paragraph.text.strip() for paragraph in document.paragraphs if paragraph.text.strip()]
    lang = detect('\n'.join(contents))
    return contents, lang[0:2]
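A small, hedged sketch of the extraction helpers in isolation (no browser or network involved). The file name and its contents are placeholders; langdetect needs at least a sentence or two to guess a language reliably.

# Hypothetical local test of the text-extraction helpers.
from contents import extract_text_from_txt

with open("sample.txt", "w", encoding="utf-8") as f:          # placeholder input file
    f.write("Bu deneysel bir projedir.\nKanun metinleri üzerinde soru-cevap yapılır.\n")

fragments, lang = extract_text_from_txt("sample.txt")
print(lang)        # expected "tr" for Turkish text
print(fragments)   # one stripped, non-empty line per fragment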
index/bkp/dd771cb6c4718ace4c4c596f4792cfdd.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c69dbad0041bb5e8e4b1e26793f8f38b56f5abb7c2db2dad201d13ce3a041d1
size 3408298
index/bkp/dd771cb6c4718ace4c4c596f4792cfdd.csv
ADDED
The diff for this file is too large to render.
index/dd771cb6c4718ace4c4c596f4792cfdd.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c69dbad0041bb5e8e4b1e26793f8f38b56f5abb7c2db2dad201d13ce3a041d1
size 3408298
index/dd771cb6c4718ace4c4c596f4792cfdd.csv
ADDED
The diff for this file is too large to render.
main.py
ADDED
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from api import api
from config import Config
from console import console
from webui import webui


def run():
    """Run the program."""
    cfg = Config()

    mode = cfg.mode
    if mode == 'console':
        console(cfg)
    elif mode == 'api':
        api(cfg)
    elif mode == 'webui':
        webui(cfg)
    else:
        raise ValueError('mode must be console, api or webui')


if __name__ == '__main__':
    run()
requirements.txt
ADDED
@@ -0,0 +1,79 @@
aiohttp
aiosignal
anyio
async-generator
async-timeout
attrs
beautifulsoup4
certifi
cffi
chardet
charset-normalizer
click
colorama
cssselect
exceptiongroup
faiss-cpu
fastapi
feedfinder2
feedparser
filelock
frozenlist
greenlet
h11
httptools
idna
jieba3k
joblib
langdetect
lxml
multidict
newspaper3k
nltk
numpy
openai
outcome
pandas
pgvector
Pillow
pycparser
pydantic
PyPDF2
PySocks
python-dateutil
python-docx
python-dotenv
python-multipart
pytz
PyYAML
readability-lxml
regex
requests
requests-file
scikit-learn
scipy
selenium
sgmllib3k
six
sniffio
sortedcontainers
soupsieve
SQLAlchemy
starlette
threadpoolctl
tiktoken
tinysegmenter
tldextract
tqdm
trio
trio-websocket
typing_extensions
urllib3
uvicorn
watchfiles
websockets
wsproto
xxhash
yarl
gradio
psycopg2
storage.py
ADDED
@@ -0,0 +1,171 @@
import os.path
from abc import ABC, abstractmethod

import faiss
import numpy as np
import pandas as pd
from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import sessionmaker, declarative_base

from config import Config

Base = declarative_base()


class Storage(ABC):
    """Abstract Storage class."""

    # factory method
    @staticmethod
    def create_storage(cfg: Config) -> 'Storage':
        """Create a storage object."""
        if cfg.use_postgres:
            return _PostgresStorage(cfg)
        else:
            return _IndexStorage(cfg)

    @abstractmethod
    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Add multiple embeddings."""
        pass

    @abstractmethod
    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Get the text for the provided embedding."""
        pass

    @abstractmethod
    def get_all_embeddings(self, name: str):
        """Get all embeddings."""
        pass

    @abstractmethod
    def clear(self, name: str):
        """Clear the database."""
        pass

    @abstractmethod
    def been_indexed(self, name: str) -> bool:
        """Check if the database has been indexed."""
        pass


class _IndexStorage(Storage):
    """IndexStorage class."""

    def __init__(self, cfg: Config):
        """Initialize the storage."""
        self._cfg = cfg

    def add_all(self, embeddings: list[tuple[str, list[float]]], name):
        """Add multiple embeddings."""
        texts, index = self._load(name)
        ids = np.array([len(texts) + i for i, _ in enumerate(embeddings)])
        texts = pd.concat([texts, pd.DataFrame(
            {'index': len(texts) + i, 'text': text} for i, (text, _) in enumerate(embeddings))])
        array = np.array([emb for text, emb in embeddings])
        index.add_with_ids(array, ids)
        self._save(texts, index, name)

    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Get the text for the provided embedding."""
        texts, index = self._load(name)
        _, indexs = index.search(np.array([embedding]), limit)
        indexs = [i for i in indexs[0] if i >= 0]
        return [f'paragraph {p}: {t}' for _, p, t in texts.iloc[indexs].values]

    def get_all_embeddings(self, name: str):
        texts, index = self._load(name)
        texts = texts.text.tolist()
        embeddings = index.reconstruct_n(0, len(texts))
        return list(zip(texts, embeddings))

    def clear(self, name: str):
        """Clear the database."""
        self._delete(name)

    def been_indexed(self, name: str) -> bool:
        return os.path.exists(os.path.join(self._cfg.index_path, f'{name}.csv')) and os.path.exists(
            os.path.join(self._cfg.index_path, f'{name}.bin'))

    def _save(self, texts, index, name: str):
        texts.to_csv(os.path.join(self._cfg.index_path, f'{name}.csv'))
        faiss.write_index(index, os.path.join(self._cfg.index_path, f'{name}.bin'))

    def _load(self, name: str):
        if self.been_indexed(name):
            texts = pd.read_csv(os.path.join(self._cfg.index_path, f'{name}.csv'))
            index = faiss.read_index(os.path.join(self._cfg.index_path, f'{name}.bin'))
        else:
            texts = pd.DataFrame(columns=['index', 'text'])
            # IDMap2 with Flat
            index = faiss.index_factory(1536, "IDMap2,Flat", faiss.METRIC_INNER_PRODUCT)
        return texts, index

    def _delete(self, name: str):
        try:
            os.remove(os.path.join(self._cfg.index_path, f'{name}.csv'))
            os.remove(os.path.join(self._cfg.index_path, f'{name}.bin'))
        except FileNotFoundError:
            pass


def singleton(cls):
    instances = {}

    def get_instance(cfg):
        if cls not in instances:
            instances[cls] = cls(cfg)
        return instances[cls]

    return get_instance


@singleton
class _PostgresStorage(Storage):
    """PostgresStorage class."""

    def __init__(self, cfg: Config):
        """Initialize the storage."""
        self._postgresql = cfg.postgres_url
        self._engine = create_engine(self._postgresql)
        Base.metadata.create_all(self._engine)
        session = sessionmaker(bind=self._engine)
        self._session = session()

    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Add multiple embeddings."""
        data = [self.EmbeddingEntity(text=text, embedding=embedding, name=name) for text, embedding in embeddings]
        self._session.add_all(data)
        self._session.commit()

    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Get the text for the provided embedding."""
        result = self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).order_by(
            self.EmbeddingEntity.embedding.cosine_distance(embedding)).limit(limit).all()
        return [f'paragraph {s.id}: {s.text}' for s in result]

    def get_all_embeddings(self, name: str):
        """Get all embeddings."""
        result = self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).all()
        return [(s.text, s.embedding) for s in result]

    def clear(self, name: str):
        """Clear the database."""
        self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).delete()
        self._session.commit()

    def been_indexed(self, name: str) -> bool:
        return self._session.query(self.EmbeddingEntity).filter_by(name=name).first() is not None

    def __del__(self):
        """Close the session."""
        self._session.close()

    class EmbeddingEntity(Base):
        __tablename__ = 'embedding'
        id = Column(Integer, primary_key=True)
        name = Column(String)
        text = Column(String)
        embedding = Column(Vector(1536))
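A quick round-trip sketch of the FAISS-backed _IndexStorage path (use_postgres set to false). This is an assumption-laden smoke test with random 1536-dimensional vectors instead of real embeddings; the document name is a placeholder, and Config() still expects the open_ai_key environment variable to be set.

# Hypothetical local test of the FAISS index storage; no OpenAI calls involved.
import numpy as np

from config import Config
from storage import Storage

cfg = Config()                                   # assumes open_ai_key is exported and use_postgres is false
storage = Storage.create_storage(cfg)

rng = np.random.default_rng(0)
fake = [(f"paragraph text {i}", rng.random(1536).tolist()) for i in range(3)]   # made-up vectors

name = "demo-doc"                                # placeholder identifier
storage.clear(name)
storage.add_all(fake, name)

query_vector = fake[0][1]
print(storage.get_texts(query_vector, name, limit=2))   # nearest stored paragraphs
print(len(storage.get_all_embeddings(name)))            # 3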
webui.py
ADDED
@@ -0,0 +1,71 @@
import gradio as gr
import xxhash
from gradio.components import _Keywords

from ai import AI
from config import Config
from contents import *
from storage import Storage


def webui(cfg: Config):
    """Run the web UI."""
    Webui(cfg).run()


class Webui:
    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.ai = AI(cfg)
        self.storage = Storage.create_storage(self.cfg)  # Initialize storage here

    def _save_to_storage(self, contents, hash_id):
        print(f"Saving to storage {hash_id}")
        print(f"Contents: \n{contents}")
        self.storage = Storage.create_storage(self.cfg)
        if self.storage.been_indexed(hash_id):
            return 0
        else:
            embeddings, tokens = self.ai.create_embeddings(contents)
            self.storage.add_all(embeddings, hash_id)
            return tokens

    def _get_hash_id(self, contents):
        return xxhash.xxh3_128_hexdigest('\n'.join(contents))

    def run(self):
        with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {visibility: hidden}") as demo:

            hash_id_state = gr.State('dd771cb6c4718ace4c4c596f4792cfdd')  # Initialize hash_id_state to the pre-indexed document
            chat_page = gr.Column(visible=True)  # Set chat_page to visible by default

            with chat_page:
                with gr.Row():
                    with gr.Column():
                        chatbot = gr.Chatbot(label="Kanunla Konuş")
                        msg = gr.Textbox(label="Sorunuzu Yazın (Bu deneysel bir projedir, tam ve doğru bilgi için, uzmanlarımıza danışın)")
                        submit_box = gr.Button("Kanuna Sor", variant="primary")

            def respond(message, chat_history, hash_id):
                kw = self.ai.get_keywords(message)
                if len(kw) == 0 or hash_id is None:
                    return "", chat_history
                _, kw_ebd = self.ai.create_embedding(kw)
                ctx = self.storage.get_texts(kw_ebd, hash_id)
                print(f"Context: \n{ctx}")
                bot_message = self.ai.completion(message, ctx)
                chat_history.append((message, bot_message))
                return "", chat_history

            def reset():
                return {
                    chat_page: gr.update(visible=True),
                    chatbot: gr.update(value=[]),
                    msg: gr.update(value=""),
                    hash_id_state: 'dd771cb6c4718ace4c4c596f4792cfdd',
                }

            msg.submit(respond, [msg, chatbot, hash_id_state], [msg, chatbot])
            submit_box.click(respond, [msg, chatbot, hash_id_state], [msg, chatbot])

        demo.title = "Kanuna Sor"
        demo.launch(server_port=self.cfg.webui_port, server_name=self.cfg.webui_host, show_api=False)