serhan committed
Commit 14e11d6 · Parent: bd35fac

Upload 16 files
README.md CHANGED
@@ -1,12 +1,11 @@
 ---
-title: I135e1fi414i41tqe
-emoji: 😻
-colorFrom: yellow
-colorTo: gray
+python_version: 3.10.6
+title: Kanunasor
+emoji: 🏆
+colorFrom: red
+colorTo: blue
 sdk: gradio
 sdk_version: 3.32.0
 app_file: app.py
 pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
ai.py ADDED
@@ -0,0 +1,197 @@
+import numpy as np
+import openai
+import tiktoken
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+from config import Config
+
+
+class AI:
+    """The AI class."""
+
+    def __init__(self, cfg: Config):
+        openai.api_key = cfg.open_ai_key
+        openai.proxy = cfg.open_ai_proxy
+        self._chat_model = cfg.open_ai_chat_model
+        self._use_stream = cfg.use_stream
+        self._encoding = tiktoken.encoding_for_model('gpt-3.5-turbo')
+        self._language = cfg.language
+        self._temperature = cfg.temperature
+
+    def _chat_stream(self, messages: list[dict], use_stream: bool = None) -> str:
+        use_stream = use_stream if use_stream is not None else self._use_stream
+        response = openai.ChatCompletion.create(
+            temperature=self._temperature,
+            stream=use_stream,
+            model=self._chat_model,
+            messages=messages,
+        )
+        if use_stream:
+            data = ""
+            for chunk in response:
+                if chunk.choices[0].delta.get('content', None) is not None:
+                    data += chunk.choices[0].delta.content
+                    print(chunk.choices[0].delta.content, end='')
+            print()
+            return data.strip()
+        else:
+            print(response.choices[0].message.content.strip())
+            print(f"Total tokens used: {response.usage.total_tokens}, "
+                  f"cost: ${response.usage.total_tokens / 1000 * 0.002}")
+            return response.choices[0].message.content.strip()
+
+    def _num_tokens_from_string(self, string: str) -> int:
+        """Returns the number of tokens in a text string."""
+        num_tokens = len(self._encoding.encode(string))
+        return num_tokens
+
+    def completion(self, query: str, context: list[str]):
+        """Create a completion."""
+        context = self._cut_texts(context)
+        print(f"Number of query fragments: {len(context)}")
+
+        text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
+        result = self._chat_stream([
+            {'role': 'system',
+             'content': f'You are a helpful AI article assistant. '
+                        f'The following are the relevant article content fragments found from the article. '
+                        f'The relevance is sorted from high to low. '
+                        f'You can only answer according to the following content:\n```\n{text}\n```\n'
+                        f'You need to carefully consider your answer to ensure that it is based on the context. '
+                        f'If the context does not mention the content or it is uncertain whether it is correct, '
+                        f'please answer "Bu bilgiye tam olarak hakim değilim, lütfen uzmanlarımıza danışın. Başka bir soru sorabilirsiniz." '
+                        f'You must use {self._language} to respond.'},
+            {'role': 'user', 'content': query},
+        ])
+        return result
+
+    def _cut_texts(self, context):
+        maximum = 4096 - 1024
+        for index, text in enumerate(context):
+            maximum -= self._num_tokens_from_string(text)
+            if maximum < 0:
+                context = context[:index + 1]
+                print(f"Exceeded maximum length, kept only the first {index + 1} fragments")
+                break
+        return context
+
+    def get_keywords(self, query: str) -> str:
+        """Get keywords from the query."""
+        result = self._chat_stream([
+            {'role': 'user',
+             'content': f'You need to extract keywords from the statement or question and '
+                        f'return a series of keywords separated by commas.\ncontent: {query}\nkeywords: '},
+        ], use_stream=False)
+        return result
+
+    @staticmethod
+    def create_embedding(text: str) -> (str, list[float]):
+        """Create an embedding for the provided text."""
+        embedding = openai.Embedding.create(model="text-embedding-ada-002", input=text)
+        return text, embedding.data[0].embedding
+
+    def create_embeddings(self, texts: list[str]) -> (list[tuple[str, list[float]]], int):
+        """Create embeddings for the provided input."""
+        result = []
+        query_len = 0
+        start_index = 0
+        tokens = 0
+
+        def get_embedding(input_slice: list[str]):
+            embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
+            return [(txt, data.embedding) for txt, data in
+                    zip(input_slice, embedding.data)], embedding.usage.total_tokens
+
+        for index, text in enumerate(texts):
+            query_len += self._num_tokens_from_string(text)
+            if query_len > 8192 - 1024:
+                ebd, tk = get_embedding(texts[start_index:index + 1])
+                print(f"Query fragments used tokens: {tk}, cost: ${tk / 1000 * 0.0004}")
+                query_len = 0
+                start_index = index + 1
+                tokens += tk
+                result.extend(ebd)
+
+        if query_len > 0:
+            ebd, tk = get_embedding(texts[start_index:])
+            print(f"Query fragments used tokens: {tk}, cost: ${tk / 1000 * 0.0004}")
+            tokens += tk
+            result.extend(ebd)
+        return result, tokens
+
+    def generate_summary(self, embeddings, num_candidates=3, use_sif=False):
+        """Generate a summary for the provided embeddings."""
+        avg_func = self._calc_paragraph_avg_embedding_with_sif if use_sif else self._calc_avg_embedding
+        avg_embedding = np.array(avg_func(embeddings))
+
+        paragraphs = [e[0] for e in embeddings]
+        embeddings = np.array([e[1] for e in embeddings])
+        # Calculate the similarity score between each paragraph and the entire text.
+        similarity_scores = cosine_similarity(embeddings, avg_embedding.reshape(1, -1)).flatten()
+
+        # Select the paragraphs with the highest similarity scores as candidates for the summary.
+        candidate_indices = np.argsort(similarity_scores)[::-1][:num_candidates]
+        candidate_paragraphs = [f"paragraph {i}: {paragraphs[i]}" for i in candidate_indices]
+
+        print("Calculation completed, start generating summary")
+
+        candidate_paragraphs = self._cut_texts(candidate_paragraphs)
+
+        text = "\n".join(f"{index}. {text}" for index, text in enumerate(candidate_paragraphs))
+        result = self._chat_stream([
+            {'role': 'system',
+             'content': f'As a helpful AI article assistant, '
+                        f'I have retrieved the following relevant text fragments from the article, '
+                        f'sorted by relevance from high to low. '
+                        f'You need to summarize the entire article from these fragments, '
+                        f'and present the final result in {self._language}:\n\n{text}\n\n{self._language} summary:'},
+        ])
+        return result
+
+    @staticmethod
+    def _calc_avg_embedding(embeddings) -> list[float]:
+        # Calculate the average embedding for the entire text.
+        avg_embedding = np.zeros(len(embeddings[0][1]))
+        for emb in embeddings:
+            avg_embedding += np.array(emb[1])
+        avg_embedding /= len(embeddings)
+        return avg_embedding.tolist()
+
+    @staticmethod
+    def _calc_paragraph_avg_embedding_with_sif(paragraph_list) -> list[float]:
+        # Calculate the SIF-weighted embedding for the entire text.
+        alpha = 0.001
+        # Calculate the total number of sentences.
+        n_sentences = len(paragraph_list)
+
+        # Calculate the total number of dimensions in the embeddings.
+        n_dims = len(paragraph_list[0][1])
+
+        # Calculate the IDF values for each word in the sentences.
+        vectorizer = TfidfVectorizer(use_idf=True)
+        vectorizer.fit_transform([paragraph for paragraph, _ in paragraph_list])
+        idf = vectorizer.idf_
+
+        # Calculate the SIF weights for each sentence.
+        weights = np.zeros((n_sentences, n_dims))
+        for i, (sentence, embedding) in enumerate(paragraph_list):
+            sentence_words = sentence.split()
+            for word in sentence_words:
+                try:
+                    word_index = vectorizer.vocabulary_[word]
+                    word_idf = idf[word_index]
+                    word_weight = alpha / (alpha + word_idf)
+                    weights[i] += word_weight * (np.array(embedding) / np.max(embedding))
+                except KeyError:
+                    pass
+
+        # Calculate the weighted average of the sentence embeddings.
+        weights_sum = np.sum(weights, axis=0)
+        weights_sum /= n_sentences
+        avg_embedding = np.zeros(n_dims)
+        for i, (sentence, embedding) in enumerate(paragraph_list):
+            avg_embedding += (np.array(embedding) / np.max(embedding)) - weights[i]
+        avg_embedding /= n_sentences
+
+        return avg_embedding.tolist()
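For orientation, here is a minimal sketch of how the `AI` class above is meant to be driven end to end. It assumes a valid `Config` (a `config.json` next to the code and the `open_ai_key` environment variable set); the sample paragraphs are made-up placeholders and the calls go out to the OpenAI API.

```python
from ai import AI
from config import Config

cfg = Config()  # needs config.json plus the open_ai_key environment variable
ai = AI(cfg)

paragraphs = ["First paragraph of the article.", "Second paragraph of the article."]  # placeholder input
embeddings, tokens = ai.create_embeddings(paragraphs)  # [(text, 1536-dim vector), ...], total tokens used
print(f"Indexed {len(embeddings)} fragments using {tokens} tokens")

# Answer a question against the fragments (normally the top matches retrieved from Storage).
answer = ai.completion("What is the article about?", [text for text, _ in embeddings])
```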
api.py ADDED
@@ -0,0 +1,117 @@
+import os
+import shutil
+
+import uvicorn
+import xxhash
+from fastapi import FastAPI, UploadFile, File
+from fastapi.exceptions import RequestValidationError
+from pydantic import BaseModel
+from starlette.exceptions import HTTPException
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+
+from ai import AI
+from config import Config
+from contents import web_crawler_newspaper, extract_text_from_txt, extract_text_from_docx, \
+    extract_text_from_pdf
+from storage import Storage
+
+
+def api(cfg: Config):
+    """Run the API."""
+
+    cfg.use_stream = False
+    ai = AI(cfg)
+
+    app = FastAPI()
+
+    class CrawlerUrlRequest(BaseModel):
+        url: str
+
+    @app.post("/crawler_url")
+    async def crawler_url(req: CrawlerUrlRequest):
+        """Crawl the URL."""
+        contents, lang = web_crawler_newspaper(req.url)
+        hash_id = xxhash.xxh3_128_hexdigest('\n'.join(contents))
+        tokens = _save_to_storage(contents, hash_id)
+        return {"code": 0, "msg": "ok", "data": {"uri": f"{hash_id}/{lang}", "tokens": tokens}}
+
+    def _save_to_storage(contents, hash_id):
+        storage = Storage.create_storage(cfg)
+        if storage.been_indexed(hash_id):
+            return 0
+        else:
+            embeddings, tokens = ai.create_embeddings(contents)
+            storage.add_all(embeddings, hash_id)
+            return tokens
+
+    @app.post("/upload_file")
+    async def create_upload_file(file: UploadFile = File(...)):
+        """Upload a file."""
+        # Save the file to disk.
+        file_name = file.filename
+        os.makedirs('./upload', exist_ok=True)
+        upload_path = os.path.join('./upload', file_name)
+        with open(upload_path, "wb") as buffer:
+            shutil.copyfileobj(file.file, buffer)
+        if file_name.endswith('.pdf'):
+            contents, lang = extract_text_from_pdf(upload_path)
+        elif file_name.endswith('.txt'):
+            contents, lang = extract_text_from_txt(upload_path)
+        elif file_name.endswith('.docx'):
+            contents, lang = extract_text_from_docx(upload_path)
+        else:
+            return {"code": 1, "msg": "not support", "data": {}}
+        hash_id = xxhash.xxh3_128_hexdigest('\n'.join(contents))
+        tokens = _save_to_storage(contents, hash_id)
+        os.remove(upload_path)
+        return {"code": 0, "msg": "ok", "data": {"uri": f"{hash_id}/{lang}", "tokens": tokens}}
+
+    @app.get("/summary")
+    async def summary(uri: str):
+        """Generate a summary."""
+        hash_id, lang = uri.split('/')
+        storage = Storage.create_storage(cfg)
+        if not storage or not lang:
+            return {"code": 1, "msg": "not found", "data": {}}
+        s = ai.generate_summary(storage.get_all_embeddings(hash_id), num_candidates=100,
+                                use_sif=lang not in ['zh', 'ja', 'ko', 'hi', 'ar', 'fa'])
+        return {"code": 0, "msg": "ok", "data": {"summary": s}}
+
+    class AnswerRequest(BaseModel):
+        uri: str
+        query: str
+
+    @app.get("/answer")
+    async def answer(req: AnswerRequest):
+        """Answer a query."""
+        hash_id, lang = req.uri.split('/')
+        storage = Storage.create_storage(cfg)
+        if not storage or not lang:
+            return {"code": 1, "msg": "not found", "data": {}}
+        keywords = ai.get_keywords(req.query)
+        _, embedding = ai.create_embedding(keywords)
+        texts = storage.get_texts(embedding, hash_id)
+        s = ai.completion(req.query, texts)
+        return {"code": 0, "msg": "ok", "data": {"answer": s}}
+
+    @app.exception_handler(RequestValidationError)
+    async def validate_error_handler(request: Request, exc: RequestValidationError):
+        """Error handler."""
+        print("validate_error_handler: ", request.url, exc)
+        return JSONResponse(
+            status_code=400,
+            content={"code": 1, "msg": str(exc.errors()), "data": {}},
+        )
+
+    @app.exception_handler(HTTPException)
+    async def http_error_handler(request: Request, exc):
+        """Error handler."""
+        print("http error_handler: ", request.url, exc)
+        return JSONResponse(
+            status_code=400,
+            content={"code": 1, "msg": exc.detail, "data": {}},
+        )
+
+    # Run the API.
+    uvicorn.run(app, host=cfg.api_host, port=cfg.api_port)
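A rough sketch of how these endpoints could be exercised once the app runs in `api` mode; the host and port follow the defaults in config.json below, and the article URL is a placeholder.

```python
import requests

base = "http://localhost:9531"  # api_host / api_port from config.json

# Index an article by URL; the response carries a "<hash_id>/<lang>" uri.
resp = requests.post(f"{base}/crawler_url", json={"url": "https://example.com/article"})
uri = resp.json()["data"]["uri"]

# Generate a summary for the indexed article.
summary = requests.get(f"{base}/summary", params={"uri": uri}).json()["data"]["summary"]
print(summary)
```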
app.py ADDED
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from api import api
+from config import Config
+from console import console
+from webui import webui
+
+
+def run():
+    """Run the program."""
+    cfg = Config()
+
+    mode = cfg.mode
+    if mode == 'console':
+        console(cfg)
+    elif mode == 'api':
+        api(cfg)
+    elif mode == 'webui':
+        webui(cfg)
+    else:
+        raise ValueError('mode must be console, api or webui')
+
+
+if __name__ == '__main__':
+    run()
config.json ADDED
@@ -0,0 +1,14 @@
+{
+  "temperature": 0.1,
+  "language": "Turkish",
+  "open_ai_chat_model": "gpt-3.5-turbo",
+  "use_stream": false,
+  "use_postgres": false,
+  "index_path": "./index",
+  "postgres_url": "postgresql://localhost:5432/mydb",
+  "mode": "webui",
+  "api_port": 9531,
+  "api_host": "localhost",
+  "webui_port": 7860,
+  "webui_host": "0.0.0.0"
+}
config.py ADDED
@@ -0,0 +1,36 @@
+import json
+import os
+
+
+class Config:
+    def __init__(self):
+        config_path = os.path.join(os.path.dirname(__file__), 'config.json')
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f'config.json not found at {config_path}, '
+                                    f'please copy config.example.json to config.json and modify it.')
+        with open(config_path, 'r') as f:
+            self.open_ai_key = os.environ.get('open_ai_key')
+            self.config = json.load(f)
+        self.language = self.config.get('language', 'Chinese')
+        self.open_ai_proxy = self.config.get('open_ai_proxy')
+        self.open_ai_chat_model = self.config.get('open_ai_chat_model', 'gpt-3.5-turbo')
+        if not self.open_ai_key:
+            raise ValueError('open_ai_key is not set')
+        self.temperature = self.config.get('temperature', 0.1)
+        if self.temperature < 0 or self.temperature > 1:
+            raise ValueError('temperature must be between 0 and 1, lower is more conservative, higher is more creative')
+        self.use_stream = self.config.get('use_stream', False)
+        self.use_postgres = self.config.get('use_postgres', False)
+        if not self.use_postgres:
+            self.index_path = self.config.get('index_path', './temp')
+            os.makedirs(self.index_path, exist_ok=True)
+        self.postgres_url = self.config.get('postgres_url')
+        if self.use_postgres and self.postgres_url is None:
+            raise ValueError('postgres_url is not set')
+        self.mode = self.config.get('mode', 'webui')
+        if self.mode not in ['console', 'api', 'webui']:
+            raise ValueError('mode must be console, api or webui')
+        self.api_port = self.config.get('api_port', 9531)
+        self.api_host = self.config.get('api_host', 'localhost')
+        self.webui_port = self.config.get('webui_port', 7860)
+        self.webui_host = self.config.get('webui_host', '0.0.0.0')
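Note that only the OpenAI key comes from the environment; every other setting is read from the config.json sitting next to config.py. A minimal sketch (the key value is a placeholder):

```python
import os

os.environ["open_ai_key"] = "sk-..."  # placeholder; must be set before Config() is constructed

from config import Config

cfg = Config()
print(cfg.mode, cfg.language, cfg.index_path)
```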
console.py ADDED
@@ -0,0 +1,120 @@
+import xxhash
+
+from ai import AI
+from config import Config
+from storage import Storage
+from contents import *
+
+
+def console(cfg: Config):
+    try:
+        while True:
+            if not _console(cfg):
+                return
+    except KeyboardInterrupt:
+        print("exit")
+
+
+def _console(cfg: Config) -> bool:
+    """Run the console."""
+
+    contents, lang, identify = _get_contents()
+
+    print("The article has been retrieved, and the number of text fragments is:", len(contents))
+    for content in contents:
+        print('\t', content)
+
+    ai = AI(cfg)
+    storage = Storage.create_storage(cfg)
+
+    print("=====================================")
+    if storage.been_indexed(identify):
+        print("The article has already been indexed, so there is no need to index it again.")
+        print("=====================================")
+    else:
+        # 1. Generate an embedding for each paragraph of the article.
+        embeddings, tokens = ai.create_embeddings(contents)
+        print(f"Embeddings have been created with {len(embeddings)} embeddings, using {tokens} tokens, "
+              f"costing ${tokens / 1000 * 0.0004}")
+
+        storage.add_all(embeddings, identify)
+        print("The embeddings have been saved.")
+        print("=====================================")
+
+    while True:
+        query = input("Please enter your query (/help to view commands):").strip()
+        if query.startswith("/"):
+            if query == "/quit":
+                return False
+            elif query == "/reset":
+                print("=====================================")
+                return True
+            elif query == "/summary":
+                # Generate an embedding-based summary, using a SIF-weighted average or a plain average depending on the language.
+                ai.generate_summary(storage.get_all_embeddings(identify), num_candidates=100,
+                                    use_sif=lang not in ['zh', 'ja', 'ko', 'hi', 'ar', 'fa'])
+            elif query == "/reindex":
+                # Re-index, which will clear the database.
+                storage.clear(identify)
+                embeddings, tokens = ai.create_embeddings(contents)
+                print(f"Embeddings have been created with {len(embeddings)} embeddings, using {tokens} tokens, "
+                      f"costing ${tokens / 1000 * 0.0004}")
+
+                storage.add_all(embeddings, identify)
+                print("The embeddings have been saved.")
+            elif query == "/help":
+                print("Enter /summary to generate an embedding-based summary.")
+                print("Enter /reindex to re-index the article.")
+                print("Enter /reset to start over.")
+                print("Enter /quit to exit.")
+                print("Enter any other content for a query.")
+            else:
+                print("Invalid command.")
+                print("Enter /summary to generate an embedding-based summary.")
+                print("Enter /reindex to re-index the article.")
+                print("Enter /reset to start over.")
+                print("Enter /quit to exit.")
+                print("Enter any other content for a query.")
+            print("=====================================")
+            continue
+        else:
+            # 1. Generate keywords.
+            print("Generating keywords.")
+            keywords = ai.get_keywords(query)
+            # 2. Generate an embedding for the question.
+            _, embedding = ai.create_embedding(keywords)
+            # 3. Find the most similar fragments from the database.
+            texts = storage.get_texts(embedding, identify)
+            print("Related fragments found (first 5):")
+            for text in texts[:5]:
+                print('\t', text)
+            # 4. Push the relevant fragments to the AI, which will answer the question based on these fragments.
+            ai.completion(query, texts)
+            print("=====================================")
+
+
+def _get_contents() -> tuple[list[str], str, str]:
+    """Get the contents."""
+
+    while True:
+        try:
+            url = input("Please enter the link to the article or the file path of the PDF/TXT/DOCX document: ").strip()
+            if os.path.exists(url):
+                if url.endswith('.pdf'):
+                    contents, data = extract_text_from_pdf(url)
+                elif url.endswith('.txt'):
+                    contents, data = extract_text_from_txt(url)
+                elif url.endswith('.docx'):
+                    contents, data = extract_text_from_docx(url)
+                else:
+                    print("Unsupported file format.")
+                    continue
+            else:
+                contents, data = web_crawler_newspaper(url)
+            if not contents:
+                print("Unable to retrieve the content of the article. Please enter the link to the article or "
+                      "the file path of the PDF/TXT/DOCX document again.")
+                continue
+            return contents, data, xxhash.xxh3_128_hexdigest('\n'.join(contents))
+        except Exception as e:
+            print("Error:", e)
contents.py ADDED
@@ -0,0 +1,81 @@
+import os
+import time
+
+import PyPDF2
+import docx
+import readability
+from langdetect import detect
+from newspaper import fulltext, Article
+from selenium import webdriver
+
+
+def web_crawler_newspaper(url: str) -> tuple[list[str], str]:
+    """Run the web crawler."""
+    raw_html, lang = _get_raw_html(url)
+    try:
+        text = fulltext(raw_html, language=lang)
+    except Exception:
+        article = Article(url)
+        article.download()
+        article.parse()
+        text = article.text
+    contents = [text.strip() for text in text.splitlines() if text.strip()]
+    return contents, lang
+
+
+def _get_raw_html(url):
+    chrome_options = webdriver.ChromeOptions()
+    chrome_options.add_argument('--headless')
+    chrome_options.add_argument('--disable-gpu')
+    chrome_options.add_argument('--no-sandbox')
+    chrome_options.add_argument('--disable-dev-shm-usage')
+    chrome_options.add_argument('--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
+                                'AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36')
+
+    with webdriver.Chrome(options=chrome_options) as driver:
+        driver.get(url)
+        print("Please wait for 5 seconds until the webpage finishes loading.")
+        time.sleep(5)
+        html = driver.page_source
+
+    doc = readability.Document(html)
+    html = doc.summary()
+    lang = detect(html)
+    return html, lang[0:2]
+
+
+def extract_text_from_pdf(file_path: str) -> tuple[list[str], str]:
+    """Extract text content from a PDF file."""
+    with open(file_path, 'rb') as f:
+        pdf_reader = PyPDF2.PdfReader(f)
+        contents = []
+        for page in pdf_reader.pages:
+            page_text = page.extract_text().strip()
+            raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
+            new_text = ''
+            for text in raw_text:
+                new_text += text
+                if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
+                                '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
+                    contents.append(new_text)
+                    new_text = ''
+            if new_text:
+                contents.append(new_text)
+    lang = detect('\n'.join(contents))
+    return contents, lang[0:2]
+
+
+def extract_text_from_txt(file_path: str) -> tuple[list[str], str]:
+    """Extract text content from a TXT file."""
+    with open(file_path, 'r', encoding='utf-8') as f:
+        contents = [text.strip() for text in f.readlines() if text.strip()]
+    lang = detect('\n'.join(contents))
+    return contents, lang[0:2]
+
+
+def extract_text_from_docx(file_path: str) -> tuple[list[str], str]:
+    """Extract text content from a DOCX file."""
+    document = docx.Document(file_path)
+    contents = [paragraph.text.strip() for paragraph in document.paragraphs if paragraph.text.strip()]
+    lang = detect('\n'.join(contents))
+    return contents, lang[0:2]
index/bkp/dd771cb6c4718ace4c4c596f4792cfdd.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c69dbad0041bb5e8e4b1e26793f8f38b56f5abb7c2db2dad201d13ce3a041d1
+size 3408298
index/bkp/dd771cb6c4718ace4c4c596f4792cfdd.csv ADDED
The diff for this file is too large to render. See raw diff
 
index/dd771cb6c4718ace4c4c596f4792cfdd.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c69dbad0041bb5e8e4b1e26793f8f38b56f5abb7c2db2dad201d13ce3a041d1
+size 3408298
index/dd771cb6c4718ace4c4c596f4792cfdd.csv ADDED
The diff for this file is too large to render. See raw diff
 
main.py ADDED
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from api import api
+from config import Config
+from console import console
+from webui import webui
+
+
+def run():
+    """Run the program."""
+    cfg = Config()
+
+    mode = cfg.mode
+    if mode == 'console':
+        console(cfg)
+    elif mode == 'api':
+        api(cfg)
+    elif mode == 'webui':
+        webui(cfg)
+    else:
+        raise ValueError('mode must be console, api or webui')
+
+
+if __name__ == '__main__':
+    run()
requirements.txt ADDED
@@ -0,0 +1,79 @@
+aiohttp
+aiosignal
+anyio
+async-generator
+async-timeout
+attrs
+beautifulsoup4
+certifi
+cffi
+chardet
+charset-normalizer
+click
+colorama
+cssselect
+exceptiongroup
+faiss-cpu
+fastapi
+feedfinder2
+feedparser
+filelock
+frozenlist
+greenlet
+h11
+httptools
+idna
+jieba3k
+joblib
+langdetect
+lxml
+multidict
+newspaper3k
+nltk
+numpy
+openai
+outcome
+pandas
+pgvector
+Pillow
+pycparser
+pydantic
+PyPDF2
+PySocks
+python-dateutil
+python-docx
+python-dotenv
+python-multipart
+pytz
+PyYAML
+readability-lxml
+regex
+requests
+requests-file
+scikit-learn
+scipy
+selenium
+sgmllib3k
+six
+sniffio
+sortedcontainers
+soupsieve
+SQLAlchemy
+starlette
+threadpoolctl
+tiktoken
+tinysegmenter
+tldextract
+tqdm
+trio
+trio-websocket
+typing_extensions
+urllib3
+uvicorn
+watchfiles
+websockets
+wsproto
+xxhash
+yarl
+gradio
+psycopg2
storage.py ADDED
@@ -0,0 +1,171 @@
+import os.path
+from abc import ABC, abstractmethod
+
+import faiss
+import numpy as np
+import pandas as pd
+from pgvector.sqlalchemy import Vector
+from sqlalchemy import create_engine, Column, Integer, String
+from sqlalchemy.orm import sessionmaker, declarative_base
+
+from config import Config
+
+Base = declarative_base()
+
+
+class Storage(ABC):
+    """Abstract Storage class."""
+
+    # factory method
+    @staticmethod
+    def create_storage(cfg: Config) -> 'Storage':
+        """Create a storage object."""
+        if cfg.use_postgres:
+            return _PostgresStorage(cfg)
+        else:
+            return _IndexStorage(cfg)
+
+    @abstractmethod
+    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
+        """Add multiple embeddings."""
+        pass
+
+    @abstractmethod
+    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
+        """Get the text for the provided embedding."""
+        pass
+
+    @abstractmethod
+    def get_all_embeddings(self, name: str):
+        """Get all embeddings."""
+        pass
+
+    @abstractmethod
+    def clear(self, name: str):
+        """Clear the database."""
+        pass
+
+    @abstractmethod
+    def been_indexed(self, name: str) -> bool:
+        """Check if the database has been indexed."""
+        pass
+
+
+class _IndexStorage(Storage):
+    """IndexStorage class."""
+
+    def __init__(self, cfg: Config):
+        """Initialize the storage."""
+        self._cfg = cfg
+
+    def add_all(self, embeddings: list[tuple[str, list[float]]], name):
+        """Add multiple embeddings."""
+        texts, index = self._load(name)
+        ids = np.array([len(texts) + i for i, _ in enumerate(embeddings)])
+        texts = pd.concat([texts, pd.DataFrame(
+            {'index': len(texts) + i, 'text': text} for i, (text, _) in enumerate(embeddings))])
+        array = np.array([emb for text, emb in embeddings])
+        index.add_with_ids(array, ids)
+        self._save(texts, index, name)
+
+    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
+        """Get the text for the provided embedding."""
+        texts, index = self._load(name)
+        _, indexs = index.search(np.array([embedding]), limit)
+        indexs = [i for i in indexs[0] if i >= 0]
+        return [f'paragraph {p}: {t}' for _, p, t in texts.iloc[indexs].values]
+
+    def get_all_embeddings(self, name: str):
+        texts, index = self._load(name)
+        texts = texts.text.tolist()
+        embeddings = index.reconstruct_n(0, len(texts))
+        return list(zip(texts, embeddings))
+
+    def clear(self, name: str):
+        """Clear the database."""
+        self._delete(name)
+
+    def been_indexed(self, name: str) -> bool:
+        return os.path.exists(os.path.join(self._cfg.index_path, f'{name}.csv')) and os.path.exists(
+            os.path.join(self._cfg.index_path, f'{name}.bin'))
+
+    def _save(self, texts, index, name: str):
+        texts.to_csv(os.path.join(self._cfg.index_path, f'{name}.csv'))
+        faiss.write_index(index, os.path.join(self._cfg.index_path, f'{name}.bin'))
+
+    def _load(self, name: str):
+        if self.been_indexed(name):
+            texts = pd.read_csv(os.path.join(self._cfg.index_path, f'{name}.csv'))
+            index = faiss.read_index(os.path.join(self._cfg.index_path, f'{name}.bin'))
+        else:
+            texts = pd.DataFrame(columns=['index', 'text'])
+            # IDMap2 with Flat
+            index = faiss.index_factory(1536, "IDMap2,Flat", faiss.METRIC_INNER_PRODUCT)
+        return texts, index
+
+    def _delete(self, name: str):
+        try:
+            os.remove(os.path.join(self._cfg.index_path, f'{name}.csv'))
+            os.remove(os.path.join(self._cfg.index_path, f'{name}.bin'))
+        except FileNotFoundError:
+            pass
+
+
+def singleton(cls):
+    instances = {}
+
+    def get_instance(cfg):
+        if cls not in instances:
+            instances[cls] = cls(cfg)
+        return instances[cls]
+
+    return get_instance
+
+
+@singleton
+class _PostgresStorage(Storage):
+    """PostgresStorage class."""
+
+    def __init__(self, cfg: Config):
+        """Initialize the storage."""
+        self._postgresql = cfg.postgres_url
+        self._engine = create_engine(self._postgresql)
+        Base.metadata.create_all(self._engine)
+        session = sessionmaker(bind=self._engine)
+        self._session = session()
+
+    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
+        """Add multiple embeddings."""
+        data = [self.EmbeddingEntity(text=text, embedding=embedding, name=name) for text, embedding in embeddings]
+        self._session.add_all(data)
+        self._session.commit()
+
+    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
+        """Get the text for the provided embedding."""
+        result = self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).order_by(
+            self.EmbeddingEntity.embedding.cosine_distance(embedding)).limit(limit).all()
+        return [f'paragraph {s.id}: {s.text}' for s in result]
+
+    def get_all_embeddings(self, name: str):
+        """Get all embeddings."""
+        result = self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).all()
+        return [(s.text, s.embedding) for s in result]
+
+    def clear(self, name: str):
+        """Clear the database."""
+        self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).delete()
+        self._session.commit()
+
+    def been_indexed(self, name: str) -> bool:
+        return self._session.query(self.EmbeddingEntity).filter_by(name=name).first() is not None
+
+    def __del__(self):
+        """Close the session."""
+        self._session.close()
+
+    class EmbeddingEntity(Base):
+        __tablename__ = 'embedding'
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+        text = Column(String)
+        embedding = Column(Vector(1536))
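`_IndexStorage` keeps one `IDMap2,Flat` inner-product FAISS index of 1536-dimensional vectors (the size of text-embedding-ada-002 embeddings) per article hash, plus a CSV holding the paragraph texts. A standalone sketch of that index layout, independent of the class above and using random stand-in vectors:

```python
import faiss
import numpy as np

dim = 1536  # text-embedding-ada-002 vector size
index = faiss.index_factory(dim, "IDMap2,Flat", faiss.METRIC_INNER_PRODUCT)

vectors = np.random.rand(10, dim).astype('float32')        # stand-ins for paragraph embeddings
index.add_with_ids(vectors, np.arange(10).astype('int64'))  # ids map back to paragraph rows

# Query with one vector; scores are inner products, ids identify the matching paragraphs.
scores, ids = index.search(vectors[:1], 3)
print(ids[0], scores[0])
```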
webui.py ADDED
@@ -0,0 +1,71 @@
+import gradio as gr
+import xxhash
+from gradio.components import _Keywords
+
+from ai import AI
+from config import Config
+from contents import *
+from storage import Storage
+
+
+def webui(cfg: Config):
+    """Run the web UI."""
+    Webui(cfg).run()
+
+
+class Webui:
+    def __init__(self, cfg: Config):
+        self.cfg = cfg
+        self.ai = AI(cfg)
+        self.storage = Storage.create_storage(self.cfg)  # Initialize storage here
+
+    def _save_to_storage(self, contents, hash_id):
+        print(f"Saving to storage {hash_id}")
+        print(f"Contents: \n{contents}")
+        self.storage = Storage.create_storage(self.cfg)
+        if self.storage.been_indexed(hash_id):
+            return 0
+        else:
+            embeddings, tokens = self.ai.create_embeddings(contents)
+            self.storage.add_all(embeddings, hash_id)
+            return tokens
+
+    def _get_hash_id(self, contents):
+        return xxhash.xxh3_128_hexdigest('\n'.join(contents))
+
+    def run(self):
+        with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {visibility: hidden}") as demo:
+
+            hash_id_state = gr.State('dd771cb6c4718ace4c4c596f4792cfdd')  # Initialize hash_id_state to 'dd771cb6c4718ace4c4c596f4792cfdd'
+            chat_page = gr.Column(visible=True)  # Set chat_page to visible by default
+
+            with chat_page:
+                with gr.Row():
+                    with gr.Column():
+                        chatbot = gr.Chatbot(label="Kanunla Konuş")
+                        msg = gr.Textbox(label="Sorunuzu Yazın (Bu deneysel bir projedir, tam ve doğru bilgi için, uzmanlarımıza danışın)")
+                        submit_box = gr.Button("Kanuna Sor", variant="primary")
+
+            def respond(message, chat_history, hash_id):
+                kw = self.ai.get_keywords(message)
+                if len(kw) == 0 or hash_id is None:
+                    return "", chat_history
+                _, kw_ebd = self.ai.create_embedding(kw)
+                ctx = self.storage.get_texts(kw_ebd, hash_id)
+                print(f"Context: \n{ctx}")
+                bot_message = self.ai.completion(message, ctx)
+                chat_history.append((message, bot_message))
+                return "", chat_history
+
+            def reset():
+                return {
+                    chat_page: gr.update(visible=True),
+                    chatbot: gr.update(value=[]),
+                    msg: gr.update(value=""),
+                    hash_id_state: 'dd771cb6c4718ace4c4c596f4792cfdd',
+                }
+
+            msg.submit(respond, [msg, chatbot, hash_id_state], [msg, chatbot])
+            submit_box.click(respond, [msg, chatbot, hash_id_state], [msg, chatbot])
+        demo.title = "Kanuna Sor"
+        demo.launch(server_port=self.cfg.webui_port, server_name=self.cfg.webui_host, show_api=False)