-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvectordb.py
330 lines (297 loc) · 12.4 KB
/
vectordb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
import os
import pickle
from langchain.vectorstores.faiss import FAISS
from langchain.vectorstores.redis import Redis
from langchain.vectorstores.chroma import Chroma
from qdrant_client import QdrantClient
from langchain.vectorstores.qdrant import Qdrant
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
class BaseEngine(object):
    """Shared base for the vector-store ingest and load helpers.

    The constructor stores ``vector_url``, creates an ``OpenAIEmbeddings``
    instance and derives ``engine_name`` from the URL.  Subclasses provide
    ``_ingest_<engine_name>`` / ``_load_<engine_name>`` methods, which
    ``_ingest`` and ``_load`` look up by name and call.
    """

    def __init__(self, vector_url):
        """Remember ``vector_url`` and select the engine it designates."""
        self.vector_url = vector_url
        self.embeddings = OpenAIEmbeddings()
        self.engine_name = self._detect_engine(vector_url)

    @staticmethod
    def _detect_engine(vector_url):
        """Return the engine name designated by ``vector_url``.

        ``None``, ``'mock'`` and ``'dummy'`` select ``"mock"``; the
        ``redis://``, ``chroma://`` and ``qdrant://`` prefixes select the
        engine of the same name; any other value selects ``"faiss"``.
        """
        # None must be tested before any str method is called on the URL,
        # otherwise startswith() raises AttributeError.
        if vector_url is None or vector_url in ("mock", "dummy"):
            return "mock"
        for scheme in ("redis", "chroma", "qdrant"):
            if vector_url.startswith(scheme + "://"):
                return scheme
        return "faiss"

    def _dispatch(self, action, **kwargs):
        """Call ``self._<action>_<engine_name>(**kwargs)`` and return its result.

        Raises:
            ValueError: if no such method exists on this object.
        """
        print(f"Using engine: {self.engine_name}")
        handler = getattr(self, f"_{action}_{self.engine_name}", None)
        if not handler:
            raise ValueError(f"Unknown engine: {self.engine_name}")
        return handler(**kwargs)

    def _ingest(self, **kwargs):
        """Run the ``_ingest_<engine_name>`` method with ``kwargs``."""
        return self._dispatch("ingest", **kwargs)

    def _load(self, **kwargs):
        """Run the ``_load_<engine_name>`` method with ``kwargs``."""
        return self._dispatch("load", **kwargs)
class Ingestor(BaseEngine):
    """Split documents into chunks and write them to the selected vector store.

    The constructor splits ``docs`` with a ``RecursiveCharacterTextSplitter``
    (chunk_size=500, chunk_overlap=10) and keeps the result in ``self.docs``.
    Ingestion drains ``self.docs`` in batches taken from the end of the list.
    """

    def __init__(self, vector_url, docs):
        """Select the engine for ``vector_url`` and pre-split ``docs``."""
        super().__init__(vector_url)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=10,
        )
        self.docs = text_splitter.split_documents(docs)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _pop(self, size=500):
        """Remove and return up to ``size`` chunks from the end of ``self.docs``.

        The returned list is in pop order (last chunk first).  Returns an
        empty list when nothing is left or ``size`` is not positive.
        """
        batch = []
        while self.docs and len(batch) < size:
            batch.append(self.docs.pop())
        return batch

    def _add_batch(self, db, docs):
        """Add ``docs`` to ``db`` and return ``(processed, unprocessed)`` counts.

        The whole batch is tried first; if that raises ``ValueError`` each
        document is retried on its own and the failing ones are skipped.
        """
        try:
            db.add_documents(documents=docs, embedding=self.embeddings)
            return len(docs), 0
        except ValueError as e:
            print(f"ERROR: {e}")
        processed = 0
        unprocessed = 0
        for doc in docs:
            try:
                db.add_documents(documents=[doc], embedding=self.embeddings)
                processed += 1
            except ValueError as e:
                unprocessed += 1
                print(f"ERROR: {e}")
                print(f"SKIPPING: {doc}")
        return processed, unprocessed

    def _ingest_batches(self, create_store):
        """Drain ``self.docs`` into a store and return it.

        ``create_store(docs)`` builds the store from the first batch; every
        later batch goes through ``_add_batch``.  Returns ``None`` when there
        were no chunks to ingest.
        """
        db = None
        while self.docs:
            print(f"Total chunks left to process: {len(self.docs)}")
            docs = self._pop()
            print(f"Processing {len(docs)} chunks...")
            if db is None:
                db = create_store(docs)
                print(f"Loaded chunks: processed: {len(docs)}, unprocessed: 0")
            else:
                processed, unprocessed = self._add_batch(db, docs)
                print(f"Loaded chunks: processed: {processed}, unprocessed: {unprocessed}")
        return db

    @staticmethod
    def _report_skipped(headline, doc, prev_doc):
        """Print the diagnostic block for a document that is being skipped."""
        print(headline)
        print("# ACTION: skipping document CURRENT_DOC")
        print(f"# CURRENT_DOC:\n{doc}\n\n")
        print(f"# PREVIOUS_DOC:\n{prev_doc}\n\n")
        print("#" * 10)

    # ------------------------------------------------------------------
    # Engine-specific ingestion
    # ------------------------------------------------------------------
    def _ingest_mock(self, **kwargs):
        """Drain ``self.docs`` without storing anything; always returns True."""
        while self.docs:
            print(f"Total chunks left to process: {len(self.docs)}")
            docs = self._pop()
            print(f"Processing {len(docs)} chunks...")
            print(f"Loaded chunks: processed: {len(docs)}, unprocessed: 0")
        return True

    def _ingest_redis(self, **kwargs):
        """Ingest all chunks into the Redis index ``plivoaskme``; returns True."""

        def create_store(docs):
            return Redis.from_documents(
                docs, self.embeddings,
                redis_url=self.vector_url, index_name='plivoaskme',
            )

        self._ingest_batches(create_store)
        return True

    def _ingest_qdrant(self, **kwargs):
        """Ingest all chunks into the Qdrant collection ``plivoaskme``.

        The API key is read from the ``QDRANT_API_KEY`` environment variable.

        Raises:
            ValueError: if nothing follows the ``qdrant://`` prefix.
        """
        url = self.vector_url.replace("qdrant://", "") or None
        api_key = os.environ.get("QDRANT_API_KEY", None)
        if not url:
            raise ValueError("Qdrant URL is required")

        def create_store(docs):
            return Qdrant.from_documents(
                docs, self.embeddings,
                url=url, api_key=api_key,
                prefer_grpc=True,
                collection_name="plivoaskme",
            )

        self._ingest_batches(create_store)
        return True

    def _ingest_chroma(self, **kwargs):
        """Ingest all chunks into a Chroma store persisted on disk.

        The directory is whatever follows the ``chroma://`` prefix; it is
        created if missing.

        Raises:
            ValueError: if nothing follows the ``chroma://`` prefix.
        """
        directory = self.vector_url.replace("chroma://", "") or None
        if not directory:
            raise ValueError("Chroma directory is required")
        # An already-existing directory is fine; other OS errors surface
        # instead of being silently swallowed.
        os.makedirs(directory, exist_ok=True)

        def create_store(docs):
            # The keyword is ``documents``; ``chunks=`` is not a parameter of
            # Chroma.from_documents.
            return Chroma.from_documents(
                documents=docs, embedding=self.embeddings,
                persist_directory=directory,
            )

        db = self._ingest_batches(create_store)
        if db is not None:
            db.persist()
        return True

    def _retry_ingest_faiss(self, docs):
        """Build a FAISS store from ``docs`` one document at a time.

        Documents whose embedding or merge fails are reported and skipped.
        Returns the merged store, or ``None`` if no document succeeded.
        """
        print(f"DEBUG: _retry_ingest_faiss start processing {len(docs)}")
        # re-init embeddings
        self.embeddings = OpenAIEmbeddings()
        db = None
        prev_doc = None
        processed = 0
        for doc in docs:
            try:
                single = FAISS.from_documents([doc], self.embeddings)
            except ValueError as e:
                self._report_skipped(f"# ERROR: FAISS.from_documents: {e}", doc, prev_doc)
                continue
            try:
                if db is None:
                    db = single
                else:
                    db.merge_from(single)
                processed += 1
            except Exception as e:
                self._report_skipped(f"# ERROR db.merge_from: {e}", doc, prev_doc)
                continue
            prev_doc = doc
        unprocessed = len(docs) - processed
        print(f"DEBUG: _retry_ingest_faiss done: processed:{processed}, unprocessed:{unprocessed}")
        return db

    def _ingest_faiss(self, **kwargs):
        """Ingest all chunks into a pickled FAISS store at ``self.vector_url``.

        Each batch is embedded and pickled to ``<vector_url>.<n>``; the part
        files are then merged and removed.  The merged store either replaces
        the target file or is merged into the existing one.

        Keyword Args:
            overwrite: when ``True`` (default) the target file is replaced;
                otherwise an existing target is loaded and merged into.
            ingest_size: number of chunks per batch (default 500).

        Returns:
            ``True`` on success, ``False`` if no part file was produced.

        Raises:
            ValueError: if ``ingest_size`` is smaller than 1.
        """
        overwrite = kwargs.get("overwrite", True)
        ingest_size = kwargs.get("ingest_size", 500)
        if ingest_size < 1:
            # A non-positive batch size would never drain self.docs.
            raise ValueError(f"ingest_size must be >= 1, got {ingest_size}")

        idx = 1
        print(f"Total chunks to process: {len(self.docs)}")
        while self.docs:
            print(f"Total chunks left to process: {len(self.docs)}")
            docs = self._pop(size=ingest_size)
            print(f"Processing {len(docs)} chunks...")
            try:
                db = FAISS.from_documents(docs, self.embeddings)
            except ValueError as e:
                print(f"ERROR FAISS.from_documents: {e}")
                db = self._retry_ingest_faiss(docs)
            if db is None:
                # Nothing in this batch could be embedded; writing a pickled
                # None would break the merge step below.
                print(f"No chunks could be ingested from this batch of {len(docs)}, skipping...")
                continue
            vector_url = self.vector_url + f".{idx}"
            print(f"Saving {len(docs)} chunks into FAISS {vector_url}")
            with open(vector_url, "wb") as f:
                pickle.dump(db, f)
            idx += 1
            print(f"Saved {len(docs)} chunks into FAISS {vector_url}")
            print(f"Processed {len(docs)} chunks...")

        orig_vector_url = self.vector_url + '.1'
        if not os.path.exists(orig_vector_url):
            print(f"No FAISS file {orig_vector_url} created, stopping...")
            return False

        # Fold every remaining part file into the first one, deleting each
        # part once it has been merged.
        db = Loader.load(orig_vector_url)
        for i in range(2, idx):
            vector_url = self.vector_url + f".{i}"
            if not os.path.exists(vector_url):
                print(f"No FAISS file {vector_url} created, skipping...")
                continue
            print(f"Merging {vector_url} into {orig_vector_url}")
            db.merge_from(Loader.load(vector_url))
            os.remove(vector_url)
            print(f"Merged {vector_url} into {orig_vector_url}")
        try:
            os.remove(orig_vector_url)
        except OSError:
            # Best-effort cleanup: a leftover part file is not fatal.
            pass

        if overwrite is True or not os.path.exists(self.vector_url):
            print(f"New FAISS file created {self.vector_url}, saving...")
            with open(self.vector_url, "wb") as f:
                pickle.dump(db, f)
            print(f"Saved data into {self.vector_url}")
            return True

        print(f"Found existing FAISS file {self.vector_url}, merging...")
        src_db = Loader.load(self.vector_url)
        src_db.merge_from(db)
        with open(self.vector_url, "wb") as f:
            pickle.dump(src_db, f)
        print(f"Merged data into {self.vector_url}")
        return True

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------
    def run(self, **kwargs):
        """Ingest ``self.docs`` with the selected engine; returns its result."""
        return self._ingest(**kwargs)

    @classmethod
    def ingest(cls, vector_url, docs, **kwargs):
        """Build an ingestor for ``vector_url`` and ``docs`` and run it."""
        return cls(vector_url, docs).run(**kwargs)
class Loader(BaseEngine):
    """Open an existing vector store for the engine selected by ``vector_url``."""

    def _load_redis(self, **kwargs):
        """Return a ``Redis`` store bound to the existing index ``plivoaskme``."""
        return Redis.from_existing_index(
            self.embeddings,
            redis_url=self.vector_url,
            index_name='plivoaskme',
        )

    def _load_qdrant(self, **kwargs):
        """Return a ``Qdrant`` store for the collection ``plivoaskme``.

        The API key is read from the ``QDRANT_API_KEY`` environment variable.

        Raises:
            ValueError: if nothing follows the ``qdrant://`` prefix.
        """
        url = self.vector_url.replace("qdrant://", "") or None
        api_key = os.environ.get("QDRANT_API_KEY", None)
        if not url:
            raise ValueError(f"Qdrant URL not found: {url}")
        client = QdrantClient(
            url=url, api_key=api_key,
            prefer_grpc=True,
        )
        return Qdrant(
            client=client, collection_name="plivoaskme",
            embeddings=self.embeddings,
        )

    def _load_chroma(self, **kwargs):
        """Return a ``Chroma`` store opened on the persisted directory.

        Raises:
            ValueError: if nothing follows the ``chroma://`` prefix.
        """
        directory = self.vector_url.replace("chroma://", "") or None
        if not directory:
            raise ValueError(f"Chroma directory not found: {directory}")
        # Use the path with the scheme stripped; passing the raw
        # ``chroma://...`` URL would point at a different directory.
        return Chroma(
            persist_directory=directory,
            embedding_function=self.embeddings,
        )

    def _load_faiss(self, **kwargs):
        """Unpickle and return the FAISS store saved at ``self.vector_url``.

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        if not os.path.exists(self.vector_url):
            raise FileNotFoundError(f"FAISS file not found: {self.vector_url}")
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only load index files that come from a trusted source.
        with open(self.vector_url, "rb") as f:
            return pickle.load(f)

    def run(self, **kwargs):
        """Load the store with the selected engine and return it."""
        return self._load(**kwargs)

    @classmethod
    def load(cls, vector_url, **kwargs):
        """Build a loader for ``vector_url`` and return the loaded store."""
        return cls(vector_url).run(**kwargs)