# Sale_Agent_Data_Indexing / data_indexing.py
# anhkhoiphan's picture
# Bổ sung data processing
# a519263
# raw
# history blame
# 55.7 kB
import io
import os
import requests
import sys
# import tempfile
# import time
from typing import List, Dict, Tuple, Any, Optional
import uuid
from PIL import Image
from FlagEmbedding import BGEM3FlagModel
import gradio as gr
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
import qdrant_client
from qdrant_client.http.models import Modifier, Distance, SparseVectorParams, VectorParams, SparseIndexParams
import torch
from transformers import EfficientNetModel, AutoImageProcessor
from pymongo import MongoClient
import contextlib
from config import (
QDRANT_COLLECTION_NAME_SPCHIEUSANG,
QDRANT_COLLECTION_NAME_SPCHUYENDUNG,
QDRANT_COLLECTION_NAME_SPPHICHNUOC,
QDRANT_COLLECTION_NAME_SPTHIETBIDIEN,
QDRANT_COLLECTION_NAME_SPNHATHONGMINH,
QDRANT_COLLECTION_NAME_GPNHATHONGMINH,
QDRANT_COLLECTION_NAME_GPHOCDUONG,
QDRANT_COLLECTION_NAME_GPNGUNGHIEP,
QDRANT_COLLECTION_NAME_GPCANHQUAN,
QDRANT_COLLECTION_NAME_GPNLMT,
QDRANT_COLLECTION_NAME_GPNNCNC,
QDRANT_COLLECTION_NAME_GPDUONGPHO,
QDRANT_COLLECTION_NAME_GPVPCS,
QDRANT_COLLECTION_NAME_GPNMCN,
QDRANT_COLLECTION_NAME_GPNOXH,
IMAGE_EMBEDDING_SIZE,
TEXT_EMBEDDING_SIZE,
IMAGE_EMBEDDING_MODEL,
TEXT_EMBEDDING_MODEL,
MONGODB_URI,
QDRANT_HOST,
QDRANT_API_KEY
)
from data_helper import *
# from src.utils.helper import client
# Module-level Qdrant client; the indexing classes in this file use it as
# their default vector-database client.
client = qdrant_client.QdrantClient(url=QDRANT_HOST, api_key=QDRANT_API_KEY, timeout=300.0)
"""=================SETTINGS========================"""
# Pick the best available torch device: CUDA first, then Apple MPS, then CPU.
# `torch.backends.mps.is_available()` is used for the MPS probe because it is
# the long-standing API; `torch.mps.is_available()` is missing on many torch
# releases and would raise AttributeError at import time there.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# Dense named vectors stored on every product point (all cosine distance).
product_vectors_config = {
    "product": VectorParams(size=TEXT_EMBEDDING_SIZE, distance=Distance.COSINE),
    "image": VectorParams(size=IMAGE_EMBEDDING_SIZE, distance=Distance.COSINE),
    "product_bgem3_dense": VectorParams(size=1024, distance=Distance.COSINE),
}
# Sparse named vector for the BGE-M3 lexical weights (in-memory index, IDF modifier).
sparse_vectors_config = {
    "product_bgem3_sparse": SparseVectorParams(
        index=SparseIndexParams(on_disk=False),
        modifier=Modifier.IDF,
    ),
}
# MongoDB collections mapping for products (product type -> MongoDB collection).
mongodb_product_collections = {
    "chieu_sang": "sp_chieu_sang",
    "chuyen_dung": "sp_chuyen_dung",
    "phich_nuoc": "sp_phich_nuoc",
    "thiet_bi_dien": "sp_thiet_bi_dien",
    "nha_thong_minh": "sp_nha_thong_minh",
}
# Product types, in the same order as `product_collections` below; the two
# lists are zipped together when indexing, so their order must stay aligned.
product_types = list(mongodb_product_collections)
product_collections = [
    QDRANT_COLLECTION_NAME_SPCHIEUSANG,
    QDRANT_COLLECTION_NAME_SPCHUYENDUNG,
    QDRANT_COLLECTION_NAME_SPPHICHNUOC,
    QDRANT_COLLECTION_NAME_SPTHIETBIDIEN,
    QDRANT_COLLECTION_NAME_SPNHATHONGMINH,
]
# MongoDB collections mapping for solutions (solution type -> MongoDB collection).
mongodb_solution_collections = {
    "canh_quan": "gp_canh_quan",
    "duong_pho": "gp_duong_pho",
    "hoc_duong": "gp_hoc_duong",
    "nha_thong_minh": "gp_nha_thong_minh",
    "ngu_nghiep": "gp_ngu_nghiep",
    "nlmt": "gp_he_thong_dien_nlmt",
    "nong_nghiep_cnc": "gp_nong_nghiep_cnc",
    "van_phong_cong_so": "gp_van_phong_cong_so",
    "nha_may_cong_nghiep": "gp_nha_may_cong_nghiep",
    "nha_o_xa_hoi": "gp_nha_o_xa_hoi",
}
# Solution types, in the same order as `solution_collections` below; the two
# lists are zipped together when indexing, so their order must stay aligned.
solution_types = list(mongodb_solution_collections)
solution_collections = [
    QDRANT_COLLECTION_NAME_GPCANHQUAN,
    QDRANT_COLLECTION_NAME_GPDUONGPHO,
    QDRANT_COLLECTION_NAME_GPHOCDUONG,
    QDRANT_COLLECTION_NAME_GPNHATHONGMINH,
    QDRANT_COLLECTION_NAME_GPNGUNGHIEP,
    QDRANT_COLLECTION_NAME_GPNLMT,
    QDRANT_COLLECTION_NAME_GPNNCNC,
    QDRANT_COLLECTION_NAME_GPVPCS,
    QDRANT_COLLECTION_NAME_GPNMCN,
    QDRANT_COLLECTION_NAME_GPNOXH,
]
class OutputCapture:
    """Context manager that redirects ``sys.stdout`` and ``sys.stderr`` into one in-memory buffer.

    Entering the context returns the underlying ``io.StringIO``; leaving it
    puts back whatever streams were installed at entry.
    """

    def __init__(self):
        self.old_stdout = None
        self.old_stderr = None
        self.buffer = io.StringIO()

    def __enter__(self):
        # Remember the current streams so __exit__ can reinstall them.
        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr
        sys.stdout = sys.stderr = self.buffer
        return self.buffer

    def __exit__(self, *args):
        sys.stdout, sys.stderr = self.old_stdout, self.old_stderr

    def getvalue(self):
        """Return everything written to the buffer so far."""
        captured = self.buffer.getvalue()
        return captured
"""=================MONGODB CONNECTION========================"""
class MongoDBConnection:
    """Small wrapper around a ``MongoClient`` bound to one database."""

    def __init__(self, connection_string: Optional[str] = None, db_name: str = "product_database"):
        """
        Store connection settings; no connection is opened here.

        Args:
            connection_string: MongoDB connection string. When ``None`` the
                module-level ``MONGODB_URI`` is used.
            db_name: Database name
        """
        self.connection_string = MONGODB_URI if connection_string is None else connection_string
        self.db_name = db_name
        self.client = None
        self.db = None

    def connect(self) -> bool:
        """
        Establish the connection and verify it with a ``ping`` command.

        Returns:
            True on success. On any failure the error is printed, the
            partially created client is closed, ``client``/``db`` are reset
            to ``None`` and False is returned.
        """
        try:
            self.client = MongoClient(self.connection_string)
            self.db = self.client[self.db_name]
            # Test connection
            self.client.admin.command('ping')
            print(f"✅ Connected to MongoDB: {self.db_name}")
            return True
        except Exception as e:
            print(f"❌ Failed to connect to MongoDB: {e}")
            # Do not leave a half-initialised client/db behind: callers use
            # `db is None` to decide whether the connection is usable.
            if self.client is not None:
                with contextlib.suppress(Exception):
                    self.client.close()
            self.client = None
            self.db = None
            return False

    def get_collection_data(self, collection_name: str) -> List[Dict]:
        """
        Retrieve all documents from a collection.

        Args:
            collection_name: Name of the MongoDB collection
        Returns:
            List of documents with ``_id`` converted to ``str``. Any error
            (including a missing connection) is printed and yields ``[]``.
        """
        try:
            collection = self.db[collection_name]
            data = list(collection.find({}))
            # Convert ObjectId to string
            for item in data:
                if '_id' in item:
                    item['_id'] = str(item['_id'])
            print(f"✅ Retrieved {len(data)} documents from {collection_name}")
            return data
        except Exception as e:
            print(f"❌ Error retrieving data from {collection_name}: {e}")
            return []

    def close(self):
        """Close the connection (if any) and clear the stale ``client``/``db`` handles."""
        if self.client is not None:
            try:
                self.client.close()
            finally:
                self.client = None
                self.db = None
            print("✅ MongoDB connection closed")
"""=================CLASS EMBEDDING========================"""
class DataEmbedding:
    """Base class with batched text and image embedding helpers."""

    def __init__(self):
        pass

    def embed_text_batch(
        self,
        contents: List[str],
        batch_size: int = 32,
        hybrid_mode: bool = False,
    ) -> List[Optional[Tuple[torch.Tensor, Any, Any]]]:
        """
        Embed texts in batches with the ``TEXT_EMBEDDING_MODEL`` and, optionally, BGE-M3.

        Args:
            contents: Texts to embed; falsy entries (``""``/``None``) are skipped.
            batch_size: Number of texts encoded per model call.
            hybrid_mode: When True, also compute BGE-M3 dense and sparse
                (lexical weight) embeddings.
        Returns:
            A list aligned with ``contents``. Each entry is ``None`` for a
            skipped text, otherwise a tuple ``(normal, bgem3_dense,
            bgem3_sparse)``; the last two are empty lists when
            ``hybrid_mode`` is False. If anything raises while loading the
            models or encoding, the error is printed and ``[]`` is returned.
        """
        # Keep the original positions of non-empty texts so results can be mapped back.
        valid_indices = [i for i, content in enumerate(contents) if content]
        if not valid_indices:
            return [None] * len(contents)
        valid_contents = [contents[i] for i in valid_indices]

        normal_embeddings, bgem3_dense_embeddings, bgem3_sparse_embeddings = [], [], []
        try:
            text_embedding_model = HuggingFaceEmbeddings(
                model_name=TEXT_EMBEDDING_MODEL,
                model_kwargs={'device': device},
                encode_kwargs={'normalize_embeddings': True}
            )
            hybrid_embedding_model = None
            if hybrid_mode:
                hybrid_embedding_model = BGEM3FlagModel(
                    "BAAI/bge-m3",
                    use_fp16=True,
                    devices=str(device)
                )
            for start in range(0, len(valid_contents), batch_size):
                batch_contents = valid_contents[start:start + batch_size]
                if hybrid_mode:
                    bgem3_embeddings = hybrid_embedding_model.encode(
                        sentences=batch_contents,
                        return_dense=True,
                        return_sparse=True
                    )
                    bgem3_dense_embeddings.extend(
                        torch.tensor(emb, dtype=torch.float32)
                        for emb in bgem3_embeddings['dense_vecs']
                    )
                    bgem3_sparse_embeddings.extend(bgem3_embeddings['lexical_weights'])
                normal_embeddings_list = text_embedding_model.embed_documents(batch_contents)
                normal_embeddings.extend(
                    torch.tensor(emb, dtype=torch.float32) for emb in normal_embeddings_list
                )
            # Map back to original order
            result = [None] * len(contents)
            for position, original_index in enumerate(valid_indices):
                if hybrid_mode:
                    result[original_index] = (
                        normal_embeddings[position],
                        bgem3_dense_embeddings[position],
                        bgem3_sparse_embeddings[position],
                    )
                else:
                    result[original_index] = (normal_embeddings[position], [], [])
            return result
        except Exception as e:
            print(f"❌ Error in batch text embedding: {str(e)[:100]}...")
            return []

    def embed_images_batch(self, image_urls: List[str], batch_size: int = 32) -> List[Optional[torch.Tensor]]:
        """
        Download images and create L2-normalised image embeddings in batches.

        Args:
            image_urls: Image URLs; falsy entries are skipped.
            batch_size: Number of images per model call.
        Returns:
            A list aligned with ``image_urls``; entries stay ``None`` for
            missing URLs, failed downloads/decodes, or failed batches.
        """
        all_embeddings: List[Optional[torch.Tensor]] = [None] * len(image_urls)
        # Loaded images together with their original index in `image_urls`.
        images_to_process: List[Tuple[Any, int]] = []
        for i, url in enumerate(image_urls):
            if not url:
                continue
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content)).convert('RGB')
                images_to_process.append((image, i))
            except requests.exceptions.RequestException as e:
                print(f"❌ HTTP error for url {url}: {e}")
            except Exception as e:
                print(f"❌ Error loading image {url}: {e}")
        if not images_to_process:
            return all_embeddings

        image_processor = AutoImageProcessor.from_pretrained(IMAGE_EMBEDDING_MODEL)
        image_embedding_model = EfficientNetModel.from_pretrained(IMAGE_EMBEDDING_MODEL).to(device)
        # Process images in batches
        for start in range(0, len(images_to_process), batch_size):
            batch_data = images_to_process[start:start + batch_size]
            batch_images = [d[0] for d in batch_data]
            batch_indices = [d[1] for d in batch_data]
            try:
                inputs = image_processor(images=batch_images, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = image_embedding_model(**inputs)
                # Keep the batch axis explicitly. A bare `.squeeze()` collapses
                # a one-image batch to 1-D, which made `normalize(dim=1)` raise
                # and silently dropped that image's embedding.
                embeddings = outputs.pooler_output.reshape(len(batch_images), -1)
                normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                for row, original_index in zip(normalized_embeddings, batch_indices):
                    all_embeddings[original_index] = row
            except Exception as e:
                print(f"❌ Error embedding image batch: {e}")
        return all_embeddings
class ProductEmbedding(DataEmbedding):
    """Builds product documents from MongoDB and turns them into Qdrant-ready point dicts."""

    @staticmethod
    def _dense_or_zeros(embedding, size: int) -> List[float]:
        """Return ``embedding`` as a list, or a zero vector of ``size`` when it is missing/empty."""
        if embedding is None or len(embedding) == 0:
            return [0.0] * size
        return embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)

    def run_embedding(self, product_type: str, mongodb_conn: MongoDBConnection,
                      batch_size: int = 32, hybrid_mode: bool = False):
        """
        Generate embeddings for a specific product type from MongoDB.

        Args:
            product_type: Type of product
            mongodb_conn: MongoDB connection object
            batch_size: Batch size for text embedding
            hybrid_mode: Whether to use hybrid text embedding (BGEM3)
        Returns:
            A list of dicts with keys ``point_id``, ``vectors`` (dense named
            vectors), ``sparse_vectors`` and ``payload``. Missing dense
            embeddings are replaced by zero vectors; ``sparse_vectors`` is
            empty when no BGE-M3 lexical weights are available.
        """
        processed_docs = self.prepare_docs(
            product_type=product_type,
            mongodb_conn=mongodb_conn
        )
        # Batch text embedding for speed
        text_contents = [doc.page_content for doc in processed_docs]
        text_embeddings = self.embed_text_batch(text_contents, batch_size, hybrid_mode)
        # Batch image embedding
        image_urls = [doc.metadata.get("image_url") for doc in processed_docs]
        image_embeddings = self.embed_images_batch(image_urls)

        embeddings = []
        for i, doc in enumerate(processed_docs):
            # embed_text_batch may return a shorter list (even []) on failure.
            text_entry = text_embeddings[i] if i < len(text_embeddings) else None
            if text_entry is not None:
                normal_text_embedding, bgem3_dense, bgem3_sparse = text_entry
            else:
                normal_text_embedding, bgem3_dense, bgem3_sparse = None, None, None
            image_embedding = image_embeddings[i] if i < len(image_embeddings) else None

            # In non-hybrid mode the BGE-M3 slots are empty lists, not None,
            # so the helper treats "missing" and "empty" the same way
            # (calling .tolist() on a plain list used to raise AttributeError).
            vectors = {
                "product": self._dense_or_zeros(normal_text_embedding, TEXT_EMBEDDING_SIZE),
                "product_bgem3_dense": self._dense_or_zeros(bgem3_dense, 1024),
                "image": self._dense_or_zeros(image_embedding, IMAGE_EMBEDDING_SIZE),
            }
            if bgem3_sparse:
                sparse_vectors = {
                    "product_bgem3_sparse": {
                        # Sparse indices must be integers; int() also accepts
                        # keys that arrive as numeric strings.
                        "indices": [int(k) for k in bgem3_sparse.keys()],
                        "values": [float(v) for v in bgem3_sparse.values()]
                    }
                }
            else:
                # No lexical weights: omit the sparse vector. The previous
                # fallback used the name "product_sparse", which is not
                # declared in `sparse_vectors_config`.
                sparse_vectors = {}

            payload = {
                "product": doc.page_content,
                "metadata": dict(doc.metadata)
            }
            embeddings.append({
                "point_id": str(uuid.uuid4()),
                "vectors": vectors,
                "sparse_vectors": sparse_vectors,
                "payload": payload
            })
        print(f"Generated {len(embeddings)} embeddings for {product_type}")
        return embeddings

    def prepare_docs(self, product_type: str, mongodb_conn: MongoDBConnection):
        """
        Prepare documents from MongoDB.

        Args:
            product_type: Type of product
            mongodb_conn: MongoDB connection object
        Returns:
            List of ``Document`` objects whose metadata is the extracted
            metadata plus ``<key>_<sub_key>`` entries for every dict-valued
            field except ``tags``.
        Raises:
            ValueError: if the connection is not established or the product
                type has no MongoDB collection mapping.
        """
        if not mongodb_conn or mongodb_conn.db is None:
            raise ValueError("MongoDB connection not established")
        collection_name = mongodb_product_collections.get(product_type)
        if not collection_name:
            raise ValueError(f"No MongoDB collection mapping for product type: {product_type}")
        data = mongodb_conn.get_collection_data(collection_name)
        print(f"🗄️ Loaded data from MongoDB collection: {collection_name}")
        docs = []
        EXCLUDE_FROM_FLATTENING = {"tags"}
        for item in data:
            content = self.create_content(item)
            metadata = self.extract_metadata(item, product_type)
            # Create a flat metadata structure for indexing
            flat_metadata = {**metadata}
            for key, value in metadata.items():
                if isinstance(value, dict) and key not in EXCLUDE_FROM_FLATTENING:
                    flat_metadata.update({f"{key}_{sub_key}": sub_value for sub_key, sub_value in value.items()})
            docs.append(Document(page_content=content, metadata=flat_metadata))
        print(f"Prepared {len(docs)} documents")
        return docs

    def create_content(self, item: Dict) -> str:
        """Build the markdown-style document content for a product."""
        product_name = item.get("Tên sản phẩm", "")
        model = item.get("Mã Sản Phẩm", "")
        summary_specs = item.get("Tóm tắt TSKT", "")
        summary_advantages = item.get("Tóm tắt ưu điểm, tính năng", "")
        specs = item.get("Thông số kỹ thuật", "")
        advantages = item.get("Nội dung Ưu điểm SP\n(- File word/Excel\n- Đặt tên file theo mã SAP)", "")
        instruction = item.get("HDSD", "")
        content = (
            f"# Tên sản phẩm: {product_name}\n\n"
            f"## Mã sản phẩm: {model}\n\n"
            f"## Tóm tắt TSKT\n{summary_specs}\n\n"
            f"### Thông số kỹ thuật chi tiết\n{specs}\n\n"
            f"## Tóm tắt ưu điểm & tính năng\n{summary_advantages}\n\n"
            f"### Ưu điểm & tính năng chi tiết\n{advantages}\n"
            f"## Hướng dẫn sử dụng: \n{instruction}\n"
        )
        return content

    def extract_metadata(self, item: Dict, product_type: str) -> Dict:
        """
        Extract metadata from a product item.

        The result merges, in increasing priority: the common fields, the
        item's ``Tags`` entries, and the values derived by
        ``process_additional_metadata``.
        """
        additional_info = ProductEmbedding.process_additional_metadata(item, product_type)
        # A stored null "Tags" would otherwise break the ** expansion below.
        tags = item.get("Tags") or {}
        common_metadata = {
            "prod_id": item.get("Product_ID", None),
            "ten_san_pham": item.get("Tên sản phẩm", ""),
            "model": item.get("Mã Sản Phẩm", ""),
            "danh_muc_l1": item.get("category 1", ""),
            "danh_muc_l2": item.get("category 2", ""),
            "danh_muc_l3": item.get("category 3", ""),
            "url": str(item.get("Link sản phẩm", "")).strip(),
            "image_url": item.get("Link ảnh sản phẩm"),
            "buy_url": item.get("Link mua hàng online", ""),
            "gia": item.get("Giá", ""),
            "tags": tags,
            **tags,
            **additional_info
        }
        return common_metadata

    @staticmethod
    def process_additional_metadata(item: Dict[str, Any], product_type) -> Dict[str, Any]:
        """
        Derive extra metadata fields from an item's spec summary, model code and name.

        Only values the extraction helpers actually return (not ``None``)
        are included in the result.
        """
        tags = item.get("Tags") or {}
        spec_text = item.get("Tóm tắt TSKT", "")
        model = item.get("Mã Sản Phẩm", "")
        prod_name = item.get("Tên sản phẩm", "")
        additional_info = {}
        # Extract cong_suat only when the tags do not already provide it
        if tags.get("cong_suat", "") == "":
            power = extract_power(spec_text)
            if power is not None:
                additional_info["cong_suat"] = power
        # Extract based on product type
        if product_type == "phich_nuoc":
            pass
        elif product_type == "chieu_sang":
            ceiling_hole_diameter = extract_ceiling_hole_diameter2(spec_text)
            if ceiling_hole_diameter is not None:
                additional_info["duong_kinh_lo_khoet_tran"] = ceiling_hole_diameter
            tinh_nang = extract_tinh_nang(model, prod_name)
            if tinh_nang is not None:
                additional_info["tinh_nang"] = tinh_nang
        elif product_type == "chuyen_dung":
            he_thong_hoa_luoi_pha = extract_he_thong_hoa_luoi_pha(prod_name)
            if he_thong_hoa_luoi_pha is not None:
                additional_info["he_thong_hoa_luoi_pha"] = he_thong_hoa_luoi_pha
        elif product_type == "thiet_bi_dien":
            dong_danh_dinh = extract_dong_danh_dinh(spec_text)
            if dong_danh_dinh is not None:
                additional_info["dong_danh_dinh"] = dong_danh_dinh
        elif product_type == "nha_thong_minh":
            cable_length = extract_cable_length(spec_text)
            if cable_length is not None:
                additional_info["chieu_dai_day"] = cable_length
            plugs_max_current = extract_plugs_max_current(spec_text)
            if plugs_max_current is not None:
                additional_info["dong_dien_o_cam_toi_da"] = plugs_max_current
            voltage = extract_voltage(model)
            if voltage is not None:
                additional_info["dien_ap"] = voltage
        return additional_info
class SolutionEmbedding(DataEmbedding):
    """Builds solution documents from MongoDB and turns them into Qdrant-ready point dicts."""

    def run_embedding(self, solution_type: str, mongodb_conn: MongoDBConnection, batch_size: int = 32):
        """
        Generate embeddings for a specific solution type from MongoDB.

        Returns a list of dicts with keys ``point_id``, ``vectors`` and
        ``payload``; a 768-long zero vector stands in for any document
        without a text embedding.
        """
        processed_docs, docs_to_embed = self.prepare_docs(solution_type, mongodb_conn)
        text_embeddings = self.embed_text_batch(
            [doc.page_content for doc in docs_to_embed], batch_size
        )
        available = len(text_embeddings)

        embeddings = []
        for position, doc in enumerate(processed_docs):
            # The stored text comes from `processed_docs`; the vector comes
            # from the entry at the same position in `docs_to_embed`.
            entry = text_embeddings[position] if position < available else None
            text_embedding = None if entry is None else entry[0]
            if text_embedding is None:
                vector = [0.0] * 768
            else:
                vector = text_embedding.tolist()
            embeddings.append({
                "point_id": str(uuid.uuid4()),
                "vectors": vector,
                "payload": {
                    "content": doc.page_content,
                    "metadata": dict(doc.metadata),
                },
            })
        print(f"Generated {len(embeddings)} embeddings for {solution_type}")
        return embeddings

    def prepare_docs(self, solution_type: str, mongodb_conn: MongoDBConnection):
        """
        Prepare documents from MongoDB.

        Args:
            solution_type: Type of solution
            mongodb_conn: MongoDB connection object
        Returns:
            ``(docs, docs_to_embed)``: two lists of equal length. ``docs``
            holds the text that is stored; ``docs_to_embed`` holds the text
            that is embedded (for ``faq`` entries only the question).
        Raises:
            ValueError: if the connection is not established or the solution
                type has no MongoDB collection mapping.
        """
        if not mongodb_conn or mongodb_conn.db is None:
            raise ValueError("MongoDB connection not established")
        collection_name = mongodb_solution_collections.get(solution_type)
        if not collection_name:
            raise ValueError(f"No MongoDB collection mapping for solution type: {solution_type}")
        data = mongodb_conn.get_collection_data(collection_name)
        print(f"🗄️ Loaded solution data from MongoDB collection: {collection_name}")

        docs = []
        docs_to_embed = []
        skipped_keys = ("_id", "san_pham")  # MongoDB _id and san_pham are not indexed
        for item in data:
            for key, val in item.items():
                if key in skipped_keys:
                    continue
                if isinstance(val, list):
                    for d in val:
                        stored_text = ". ".join(f"{k}: {v}" for k, v in d.items())
                        docs.append(Document(page_content=stored_text, metadata={"category": key}))
                        if key == "faq":
                            embed_text = f"Câu hỏi: {d.get('Câu hỏi', '')}"
                        else:
                            embed_text = stored_text
                        docs_to_embed.append(Document(page_content=embed_text, metadata={"category": key}))
                elif isinstance(val, dict):
                    for k, v in val.items():
                        pair_text = f"{k}: {v}"
                        docs_to_embed.append(Document(page_content=pair_text, metadata={"category": key}))
                        docs.append(Document(page_content=pair_text, metadata={"category": key}))
        print(f"Prepared {len(docs)} documents")
        return docs, docs_to_embed
"""=================CLASS INDEXING========================"""
class ProductIndexing:
    """Embeds product data from MongoDB and indexes it into the product Qdrant collections."""

    # Payload fields indexed for every product type: (metadata field, schema kind).
    _COMMON_PAYLOAD_INDEXES = (
        ("danh_muc_l2", "keyword"),
        ("danh_muc_l3", "keyword"),
        ("gia", "integer"),
        ("cong_suat", "float"),
    )
    # Additional payload fields per product type, created in this order.
    # NOTE(review): "he_thong_hoa_luoi" does not match the "he_thong_hoa_luoi_pha"
    # key written by ProductEmbedding.process_additional_metadata - confirm
    # which name is intended.
    _PAYLOAD_INDEXES_BY_TYPE = {
        "phich_nuoc": (
            ("dung_tich", "float"),
            ("chat_lieu", "keyword"),
            ("tinh_nang", "keyword"),
        ),
        "chieu_sang": (
            ("kich_thuoc", "float"),
            ("duong_kinh_lo_khoet_tran", "keyword"),
            ("tinh_nang", "keyword"),
        ),
        "chuyen_dung": (
            ("nhiet_do_mau", "keyword"),
            ("dien_ap", "keyword"),
            ("cong_nghe_led", "keyword"),
            ("loai_den", "keyword"),
            ("he_thong_hoa_luoi", "keyword"),
        ),
        "thiet_bi_dien": (
            ("dong_danh_dinh", "float"),
            ("anh_sang", "keyword"),
            ("so_hat", "keyword"),
            ("so_cuc", "keyword"),
            ("modules", "keyword"),
            ("doi_tuong", "keyword"),
            ("cong_nghe", "keyword"),
            ("loai_den", "keyword"),
            ("san_pham", "keyword"),
        ),
        "nha_thong_minh": (
            ("chieu_dai_day", "float"),
            ("lo_khoet_tran", "integer"),
            ("nut_bam", "keyword"),
            ("dong_dien_o_cam_toi_da", "integer"),
            ("dien_ap", "integer"),
            ("hinh_dang", "keyword"),
            ("tinh_nang", "text"),
            ("goc_chieu", "text"),
            ("combo", "text"),
            ("anh_sang", "text"),
        ),
    }

    def __init__(self, vector_db_client=client):
        super().__init__()
        self.client = vector_db_client
        self.mongodb_conn = None

    def setup_mongodb(self, connection_string: str = None):
        """
        Setup MongoDB connection.

        Returns True when the connection succeeds. On failure
        ``self.mongodb_conn`` is reset to ``None`` so a later run retries
        instead of reusing a dead connection object.
        """
        connection = MongoDBConnection(connection_string)
        if connection.connect():
            self.mongodb_conn = connection
            return True
        self.mongodb_conn = None
        return False

    def _close_mongodb(self):
        """Close and forget the MongoDB connection, if one is open."""
        if self.mongodb_conn:
            self.mongodb_conn.close()
            self.mongodb_conn = None

    def index(
        self,
        embeddings: List[Dict],
        collection_name: str,
        batch_size: int = 100
    ):
        """
        Index embeddings to a Qdrant collection in batches.

        Each embedding dict must provide ``point_id``, ``vectors`` (dense
        named vectors) and ``payload``; ``sparse_vectors`` is merged into
        the named vectors when present. A batch that raises is counted as
        failed and the remaining batches are still attempted.
        """
        total_docs = len(embeddings)
        success_count = 0
        error_count = 0
        print(f"Adding {total_docs} multimodal documents to '{collection_name}'...")
        for i in range(0, total_docs, batch_size):
            batch = embeddings[i:i+batch_size]
            batch_number = i // batch_size + 1
            try:
                points = []
                for embedding_data in batch:
                    combined_vectors = embedding_data["vectors"].copy()
                    combined_vectors.update(embedding_data.get("sparse_vectors") or {})
                    points.append(
                        qdrant_client.http.models.PointStruct(
                            id=embedding_data["point_id"],
                            vector=combined_vectors,
                            payload=embedding_data["payload"]
                        )
                    )
                self.client.upsert(collection_name=collection_name, points=points)
                success_count += len(batch)
                # All-zero vectors are placeholders for missing embeddings.
                text_count = sum(1 for p in points if any(v != 0 for v in p.vector["product"]))
                image_count = sum(1 for p in points if any(v != 0 for v in p.vector["image"]))
                print(f"✅ Batch {batch_number}: {len(batch)} docs | {text_count} product | {image_count} images")
            except Exception as e:
                error_count += len(batch)
                print(f"❌ Batch {batch_number} failed: {e}")
        print(f"\n📊 Final Results:")
        print(f" ✅ Successful: {success_count}")
        print(f" ❌ Failed: {error_count}")
        attempted = success_count + error_count
        if attempted:
            print(f" 📈 Success Rate: {success_count/attempted*100:.1f}%")
        else:
            # Nothing to index: avoid dividing by zero.
            print(" 📈 Success Rate: N/A (no documents)")

    def run_indexing(self, reload: bool = True, hybrid_mode: bool = True):
        """
        Index all product data from MongoDB into Qdrant collections.

        Args:
            reload: Whether to recreate collections
            hybrid_mode: Whether to use hybrid text embedding (BGEM3)
        Returns:
            Everything printed during the run (stdout and stderr), as a string.
        """
        import traceback

        with OutputCapture() as output:
            try:
                if reload:
                    try:
                        for collection in product_collections:
                            self.client.recreate_collection(
                                collection_name=collection,
                                vectors_config=product_vectors_config,
                                sparse_vectors_config=sparse_vectors_config
                            )
                        print("✅ All product collections recreated.")
                    except Exception as e:
                        print(f"❌ Error while recreating collections: {e}")
                        return output.getvalue()
                # Setup MongoDB connection
                if not self.mongodb_conn:
                    if not self.setup_mongodb():
                        print("❌ Failed to connect to MongoDB. Aborting indexing.")
                        return output.getvalue()
                try:
                    embed_object = ProductEmbedding()
                    for collection, product_type in zip(product_collections, product_types):
                        print(f"\n{'='*60}")
                        print(f"🔄 Processing {product_type} data from MongoDB...")
                        print(f"{'='*60}")
                        try:
                            embeddings = embed_object.run_embedding(
                                product_type=product_type,
                                mongodb_conn=self.mongodb_conn,
                                hybrid_mode=hybrid_mode
                            )
                            self.index(embeddings, collection)
                            self._create_payload_indexes_for_product_type(product_type, collection)
                            print(f"✅ Completed indexing for {product_type}")
                        except Exception as e:
                            # One failing product type must not stop the others.
                            print(f"❌ Error indexing {product_type}: {e}")
                            print(traceback.format_exc())
                finally:
                    # Close the connection even if something unexpected escaped the loop.
                    self._close_mongodb()
                print(f"\n{'='*60}")
                print("🎉 All indexing completed!")
                print(f"{'='*60}")
            except Exception as e:
                print(f"❌ Fatal error during indexing: {e}")
                print(traceback.format_exc())
        return output.getvalue()

    def indexing_single_product_type(self, product_type: str, collection_name: str,
                                     hybrid_mode: bool = True) -> str:
        """
        Index a single product group into its Qdrant collection from MongoDB.

        The collection is always recreated first.

        Args:
            product_type: Type of product
            collection_name: Qdrant collection name
            hybrid_mode: Whether to use hybrid text embedding (BGEM3)
        Returns:
            Everything printed during the run (stdout and stderr), as a string.
        """
        import traceback

        with OutputCapture() as output:
            try:
                print(f"{'='*60}")
                print(f"🚀 Starting indexing for {product_type}")
                print(f"{'='*60}\n")
                self.client.recreate_collection(
                    collection_name=collection_name,
                    vectors_config=product_vectors_config,
                    sparse_vectors_config=sparse_vectors_config
                )
                print(f"✅ Collection {collection_name} created\n")
                # Setup MongoDB connection
                if not self.mongodb_conn:
                    if not self.setup_mongodb():
                        print("❌ Failed to connect to MongoDB")
                        return output.getvalue()
                try:
                    embed_object = ProductEmbedding()
                    print(f"🔄 Processing {product_type} data from MongoDB...")
                    embeddings = embed_object.run_embedding(
                        product_type=product_type,
                        mongodb_conn=self.mongodb_conn,
                        hybrid_mode=hybrid_mode
                    )
                    print(f"\n📊 Indexing to Qdrant...")
                    self.index(embeddings, collection_name)
                    print(f"\n🔍 Creating payload indexes...")
                    self._create_payload_indexes_for_product_type(product_type, collection_name)
                finally:
                    # Release the connection on failure as well as on success.
                    self._close_mongodb()
                print(f"\n{'='*60}")
                print(f"✅ Successfully completed indexing for {product_type}")
                print(f"{'='*60}")
            except Exception as e:
                print(f"❌ Error while indexing product type {product_type}: {e}")
                print(traceback.format_exc())
        return output.getvalue()

    def _create_payload_indexes_for_product_type(self, product_type: str, collection_name: str):
        """
        Create payload indexes based on product type field schemas.

        Creates the common indexes, then the ones listed for ``product_type``
        (none for an unknown type). The first failure is printed and stops
        the remaining index creation; nothing is raised.
        """
        print(f"🔍 Creating payload indexes for {product_type}...")
        try:
            models = qdrant_client.http.models
            schema_factories = {
                "keyword": models.KeywordIndexParams,
                "integer": models.IntegerIndexParams,
                "float": models.FloatIndexParams,
                "text": models.TextIndexParams,
            }
            fields = self._COMMON_PAYLOAD_INDEXES + self._PAYLOAD_INDEXES_BY_TYPE.get(product_type, ())
            for field, kind in fields:
                self.client.create_payload_index(
                    collection_name=collection_name,
                    field_name=f"metadata.{field}",
                    field_schema=schema_factories[kind](type=kind)
                )
            print(f"✅ All payload indexes created for {product_type}")
        except Exception as e:
            print(f"❌ Error creating payload indexes for {product_type}: {e}")
class SolutionIndexing:
def __init__(self, vector_db_client=client):
    """Remember the vector-database client; the MongoDB connection is opened later."""
    super().__init__()
    self.mongodb_conn = None
    self.client = vector_db_client
def setup_mongodb(self, connection_string: str = None):
    """
    Setup MongoDB connection.

    Returns True when the connection succeeds. On failure
    ``self.mongodb_conn`` is reset to ``None`` so that a later
    ``if not self.mongodb_conn`` check retries instead of reusing a dead
    connection object.
    """
    connection = MongoDBConnection(connection_string)
    if connection.connect():
        self.mongodb_conn = connection
        return True
    self.mongodb_conn = None
    return False
def index(
    self,
    embeddings: List[Dict],
    collection_name: str,
    batch_size: int = 10
):
    """
    Index embeddings to a Qdrant collection in batches.

    Each embedding dict must provide ``point_id``, ``vectors`` and
    ``payload``. A batch that raises is counted as failed and the remaining
    batches are still attempted.
    """
    total_docs = len(embeddings)
    success_count = 0
    error_count = 0
    print(f"Adding {total_docs} solution documents to '{collection_name}'...")
    for i in range(0, total_docs, batch_size):
        batch = embeddings[i:i+batch_size]
        batch_number = i // batch_size + 1
        try:
            # Create Qdrant points from embedding data
            points = [
                qdrant_client.http.models.PointStruct(
                    id=embedding_data["point_id"],
                    vector=embedding_data["vectors"],
                    payload=embedding_data["payload"]
                )
                for embedding_data in batch
            ]
            # Upload batch to Qdrant
            self.client.upsert(collection_name=collection_name, points=points)
            success_count += len(batch)
            # All-zero vectors are placeholders for missing embeddings.
            text_count = sum(1 for p in points if any(v != 0 for v in p.vector))
            print(f"✅ Batch {batch_number}: {len(batch)} docs | {text_count} contents")
        except Exception as e:
            error_count += len(batch)
            print(f"❌ Batch {batch_number} failed: {e}")
    print(f"\n📊 Final Results:")
    print(f" ✅ Successful: {success_count}")
    print(f" ❌ Failed: {error_count}")
    attempted = success_count + error_count
    if attempted:
        print(f" 📈 Success Rate: {success_count/attempted*100:.1f}%")
    else:
        # Nothing to index: avoid dividing by zero.
        print(" 📈 Success Rate: N/A (no documents)")
def run_indexing(self, reload: bool = True):
"""Index all solution data from MongoDB into Qdrant collections."""
with OutputCapture() as output:
try:
if reload:
try:
for collection in solution_collections:
self.client.recreate_collection(
collection_name=collection,
vectors_config=qdrant_client.http.models.VectorParams(
size=768,
distance=qdrant_client.http.models.Distance.COSINE,
)
)
print("✅ All solution collections recreated.")
except Exception as e:
print(f"❌ Error while recreating collections: {e}")
return output.getvalue()
# Setup MongoDB connection
if not self.mongodb_conn:
if not self.setup_mongodb():
print("❌ Failed to connect to MongoDB. Aborting indexing.")
return output.getvalue()
# Create embedding processor
embed_object = SolutionEmbedding()
for collection, solution_type in zip(solution_collections, solution_types):
print(f"\n{'='*60}")
print(f"🔄 Processing {solution_type} data from MongoDB...")
print(f"{'='*60}")
try:
embeddings = embed_object.run_embedding(solution_type, self.mongodb_conn)
self.index(embeddings, collection)
print(f"✅ Completed indexing for {solution_type}")
except Exception as e:
print(f"❌ Error indexing {solution_type}: {e}")
import traceback
print(traceback.format_exc())
# Close MongoDB connection
if self.mongodb_conn:
self.mongodb_conn.close()
self.mongodb_conn = None
print(f"\n{'='*60}")
print("🎉 All solution indexing completed!")
print(f"{'='*60}")
except Exception as e:
print(f"❌ Fatal error during indexing: {e}")
import traceback
print(traceback.format_exc())
return output.getvalue()
def indexing_single_solution(self, solution: str, collection_name: str) -> str:
"""Indexing a single solution into its Qdrant collection from MongoDB"""
with OutputCapture() as output:
try:
print(f"{'='*60}")
print(f"🚀 Starting indexing for {solution}")
print(f"{'='*60}\n")
self.client.recreate_collection(
collection_name=collection_name,
vectors_config=qdrant_client.http.models.VectorParams(
size=768,
distance=qdrant_client.http.models.Distance.COSINE,
)
)
print(f"✅ Collection {collection_name} created\n")
# Setup MongoDB connection
if not self.mongodb_conn:
if not self.setup_mongodb():
print("❌ Failed to connect to MongoDB")
return output.getvalue()
# Create embedding processor
embed_object = SolutionEmbedding()
print(f"🔄 Processing {solution} data from MongoDB...")
embeddings = embed_object.run_embedding(solution, self.mongodb_conn)
print(f"\n📊 Indexing to Qdrant...")
self.index(embeddings, collection_name)
# Close MongoDB connection
if self.mongodb_conn:
self.mongodb_conn.close()
self.mongodb_conn = None
print(f"\n{'='*60}")
print(f"✅ Successfully completed indexing for {solution}")
print(f"{'='*60}")
except Exception as e:
print(f"❌ Error while indexing solution {solution}: {e}")
import traceback
print(traceback.format_exc())
return output.getvalue()
"""=================GRADIO UI========================"""
def create_indexing_interface():
    """Build the Gradio app used to trigger indexing from MongoDB.

    The page shows one shared log textbox, a button per solution and per
    product type, and an "all" button for each group. Every handler receives
    fixed arguments through ``gr.State`` components and writes its returned
    text into the log textbox.

    Returns:
        The assembled ``gr.Blocks`` instance; it is not launched here.
    """
    product_indexing = ProductIndexing()
    solution_indexing = SolutionIndexing()

    # (button label, product type key, Qdrant collection) in display order.
    product_specs = (
        ("SP Phích nước", "phich_nuoc", QDRANT_COLLECTION_NAME_SPPHICHNUOC),
        ("SP Chiếu sáng", "chieu_sang", QDRANT_COLLECTION_NAME_SPCHIEUSANG),
        ("SP Chuyên dụng", "chuyen_dung", QDRANT_COLLECTION_NAME_SPCHUYENDUNG),
        ("SP Nhà thông minh", "nha_thong_minh", QDRANT_COLLECTION_NAME_SPNHATHONGMINH),
        ("SP Thiết bị điện", "thiet_bi_dien", QDRANT_COLLECTION_NAME_SPTHIETBIDIEN),
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🗄️ Qdrant Data Indexing System (MongoDB)")
        gr.Markdown("Recreate Qdrant Collections and Index Data from MongoDB Atlas")
        output_box = gr.Textbox(
            lines=20,
            label="📋 Logs",
            interactive=False,
            show_copy_button=True,
            max_lines=30
        )

        def add_solution_button(label, solution_key, collection_name):
            # Creates the button in whatever layout context is currently open.
            gr.Button(label).click(
                solution_indexing.indexing_single_solution,
                inputs=[gr.State(solution_key), gr.State(collection_name)],
                outputs=output_box)

        gr.Markdown("---")
        gr.Markdown("## 🏢 Giải pháp (Solutions)")
        with gr.Row():
            add_solution_button("GP Ngư nghiệp", "ngu_nghiep", QDRANT_COLLECTION_NAME_GPNGUNGHIEP)
            add_solution_button("GP Học đường", "hoc_duong", QDRANT_COLLECTION_NAME_GPHOCDUONG)
            add_solution_button("GP Nhà thông minh", "nha_thong_minh", QDRANT_COLLECTION_NAME_GPNHATHONGMINH)
            add_solution_button("GP Nông nghiệp CNC", "nong_nghiep_cnc", QDRANT_COLLECTION_NAME_GPNNCNC)
        with gr.Row():
            add_solution_button("GP Cảnh quan", "canh_quan", QDRANT_COLLECTION_NAME_GPCANHQUAN)
            add_solution_button("GP HTĐ NLMT", "nlmt", QDRANT_COLLECTION_NAME_GPNLMT)
            add_solution_button("GP Đường phố", "duong_pho", QDRANT_COLLECTION_NAME_GPDUONGPHO)
            add_solution_button("GP Văn phòng công sở", "van_phong_cong_so", QDRANT_COLLECTION_NAME_GPVPCS)
        with gr.Row():
            add_solution_button("GP Nhà máy CN", "nha_may_cong_nghiep", QDRANT_COLLECTION_NAME_GPNMCN)
            add_solution_button("GP Nhà ở xã hội", "nha_o_xa_hoi", QDRANT_COLLECTION_NAME_GPNOXH)
            gr.Button("✨ Tất cả GP", variant="primary").click(
                solution_indexing.run_indexing,
                inputs=gr.State(True),
                outputs=output_box)

        gr.Markdown("---")
        gr.Markdown("## 📦 Sản phẩm (Products)")
        # All product buttons are laid out first; handlers are attached below.
        with gr.Row():
            product_buttons = [gr.Button(label) for label, _, _ in product_specs]
        with gr.Row():
            btn_all_products = gr.Button("✨ Tất cả SP", variant="primary", scale=2)

        for button, (_, product_key, collection_name) in zip(product_buttons, product_specs):
            button.click(
                product_indexing.indexing_single_product_type,
                inputs=[gr.State(product_key), gr.State(collection_name), gr.State(True)],
                outputs=output_box)
        btn_all_products.click(
            product_indexing.run_indexing,
            inputs=[gr.State(True), gr.State(True)],
            outputs=output_box)

    return demo
if __name__ == "__main__":
    # Script entry point: build the Gradio interface and start serving it.
    demo = create_indexing_interface()
    demo.launch()