|
|
import io |
|
|
import os |
|
|
import requests |
|
|
import sys |
|
|
|
|
|
|
|
|
from typing import List, Dict, Tuple, Any, Optional |
|
|
import uuid |
|
|
|
|
|
from PIL import Image |
|
|
from FlagEmbedding import BGEM3FlagModel |
|
|
import gradio as gr |
|
|
from langchain_core.documents import Document |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
import qdrant_client |
|
|
from qdrant_client.http.models import Modifier, Distance, SparseVectorParams, VectorParams, SparseIndexParams |
|
|
import torch |
|
|
from transformers import EfficientNetModel, AutoImageProcessor |
|
|
from pymongo import MongoClient |
|
|
import contextlib |
|
|
|
|
|
from config import ( |
|
|
QDRANT_COLLECTION_NAME_SPCHIEUSANG, |
|
|
QDRANT_COLLECTION_NAME_SPCHUYENDUNG, |
|
|
QDRANT_COLLECTION_NAME_SPPHICHNUOC, |
|
|
QDRANT_COLLECTION_NAME_SPTHIETBIDIEN, |
|
|
QDRANT_COLLECTION_NAME_SPNHATHONGMINH, |
|
|
QDRANT_COLLECTION_NAME_GPNHATHONGMINH, |
|
|
QDRANT_COLLECTION_NAME_GPHOCDUONG, |
|
|
QDRANT_COLLECTION_NAME_GPNGUNGHIEP, |
|
|
QDRANT_COLLECTION_NAME_GPCANHQUAN, |
|
|
QDRANT_COLLECTION_NAME_GPNLMT, |
|
|
QDRANT_COLLECTION_NAME_GPNNCNC, |
|
|
QDRANT_COLLECTION_NAME_GPDUONGPHO, |
|
|
QDRANT_COLLECTION_NAME_GPVPCS, |
|
|
QDRANT_COLLECTION_NAME_GPNMCN, |
|
|
QDRANT_COLLECTION_NAME_GPNOXH, |
|
|
IMAGE_EMBEDDING_SIZE, |
|
|
TEXT_EMBEDDING_SIZE, |
|
|
IMAGE_EMBEDDING_MODEL, |
|
|
TEXT_EMBEDDING_MODEL, |
|
|
MONGODB_URI, |
|
|
QDRANT_HOST, |
|
|
QDRANT_API_KEY |
|
|
) |
|
|
|
|
|
from data_helper import * |
|
|
|
|
|
|
|
|
# Module-wide Qdrant client; the indexing classes below use it as their default
# vector store.
client = qdrant_client.QdrantClient(
    url=QDRANT_HOST,
    api_key=QDRANT_API_KEY,
    timeout=300.0,
)

# ================= SETTINGS =================

# Preferred compute device: CUDA first, then Apple MPS, then CPU.
# The MPS probe goes through ``torch.backends.mps`` because
# ``torch.mps.is_available`` is not present in every PyTorch release that
# supports MPS, and this expression is evaluated on every machine without CUDA.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Dense named vectors stored on every product point.
product_vectors_config = {
    # Text embedding of the product document.
    "product": VectorParams(
        size=TEXT_EMBEDDING_SIZE,
        distance=Distance.COSINE,
    ),
    # Embedding of the product image.
    "image": VectorParams(
        size=IMAGE_EMBEDDING_SIZE,
        distance=Distance.COSINE,
    ),
    # Dense half of the BGE-M3 hybrid embedding.
    "product_bgem3_dense": VectorParams(
        size=1024,
        distance=Distance.COSINE,
    ),
}

# Sparse named vector (lexical half of the BGE-M3 hybrid embedding), kept in
# memory and scored with the IDF modifier.
sparse_vectors_config = {
    "product_bgem3_sparse": SparseVectorParams(
        index=SparseIndexParams(on_disk=False),
        modifier=Modifier.IDF,
    ),
}
|
|
|
|
|
# --- Products ---------------------------------------------------------------
# Qdrant collection per product group. The order is significant: the product
# indexing code walks this list in lockstep with ``product_types``.
product_collections = [
    QDRANT_COLLECTION_NAME_SPCHIEUSANG,
    QDRANT_COLLECTION_NAME_SPCHUYENDUNG,
    QDRANT_COLLECTION_NAME_SPPHICHNUOC,
    QDRANT_COLLECTION_NAME_SPTHIETBIDIEN,
    QDRANT_COLLECTION_NAME_SPNHATHONGMINH,
]

# Product type keys, position-aligned with ``product_collections``.
product_types = [
    "chieu_sang", "chuyen_dung", "phich_nuoc", "thiet_bi_dien", "nha_thong_minh",
]

# Product type key -> MongoDB collection name holding the source documents.
mongodb_product_collections = dict(
    chieu_sang="sp_chieu_sang",
    chuyen_dung="sp_chuyen_dung",
    phich_nuoc="sp_phich_nuoc",
    thiet_bi_dien="sp_thiet_bi_dien",
    nha_thong_minh="sp_nha_thong_minh",
)

# --- Solutions --------------------------------------------------------------
# Qdrant collection per solution group.
# NOTE(review): presumably paired by position with ``solution_types`` the same
# way the product lists are -- confirm against the solution indexing code.
solution_collections = [
    QDRANT_COLLECTION_NAME_GPCANHQUAN,
    QDRANT_COLLECTION_NAME_GPDUONGPHO,
    QDRANT_COLLECTION_NAME_GPHOCDUONG,
    QDRANT_COLLECTION_NAME_GPNHATHONGMINH,
    QDRANT_COLLECTION_NAME_GPNGUNGHIEP,
    QDRANT_COLLECTION_NAME_GPNLMT,
    QDRANT_COLLECTION_NAME_GPNNCNC,
    QDRANT_COLLECTION_NAME_GPVPCS,
    QDRANT_COLLECTION_NAME_GPNMCN,
    QDRANT_COLLECTION_NAME_GPNOXH,
]

# Solution type keys, listed in the same order as ``solution_collections``.
solution_types = [
    "canh_quan", "duong_pho", "hoc_duong", "nha_thong_minh", "ngu_nghiep",
    "nlmt", "nong_nghiep_cnc", "van_phong_cong_so", "nha_may_cong_nghiep",
    "nha_o_xa_hoi",
]

# Solution type key -> MongoDB collection name holding the source documents.
mongodb_solution_collections = dict(
    canh_quan="gp_canh_quan",
    duong_pho="gp_duong_pho",
    hoc_duong="gp_hoc_duong",
    nha_thong_minh="gp_nha_thong_minh",
    ngu_nghiep="gp_ngu_nghiep",
    nlmt="gp_he_thong_dien_nlmt",
    nong_nghiep_cnc="gp_nong_nghiep_cnc",
    van_phong_cong_so="gp_van_phong_cong_so",
    nha_may_cong_nghiep="gp_nha_may_cong_nghiep",
    nha_o_xa_hoi="gp_nha_o_xa_hoi",
)
|
|
|
|
|
class OutputCapture:
    """Redirect ``sys.stdout`` and ``sys.stderr`` into one in-memory buffer.

    Used as a context manager: entering swaps both streams for a shared
    ``io.StringIO`` and yields that buffer; leaving puts the previous streams
    back. Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self):
        # Streams that were active before entering; filled in by __enter__.
        self.old_stdout = self.old_stderr = None
        self.buffer = io.StringIO()

    def __enter__(self):
        # Remember the current streams, then point both at the shared buffer.
        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr
        sys.stdout = sys.stderr = self.buffer
        return self.buffer

    def __exit__(self, *args):
        # Restore whatever streams were active when the block was entered.
        sys.stdout, sys.stderr = self.old_stdout, self.old_stderr

    def getvalue(self):
        """Return everything written to the buffer so far."""
        return self.buffer.getvalue()
|
|
|
|
|
"""=================MONGODB CONNECTION========================""" |
|
|
class MongoDBConnection:
    """Small wrapper around a ``MongoClient`` bound to a single database."""

    def __init__(self, connection_string: Optional[str] = None, db_name: str = "product_database"):
        """
        Initialize MongoDB connection settings (no network access happens here).

        Args:
            connection_string: MongoDB connection string; falls back to
                ``MONGODB_URI`` when ``None``
            db_name: Database name
        """
        self.connection_string = MONGODB_URI if connection_string is None else connection_string
        self.db_name = db_name
        self.client = None
        self.db = None

    def connect(self) -> bool:
        """
        Establish the connection and verify it with a ``ping`` command.

        Returns:
            True when the server answered the ping, False otherwise. On
            failure the partially created client is closed and ``client`` /
            ``db`` are reset to ``None`` so that ``db is None`` checks made by
            callers reflect the real connection state.
        """
        try:
            self.client = MongoClient(self.connection_string)
            self.db = self.client[self.db_name]
            self.client.admin.command('ping')
        except Exception as e:
            print(f"❌ Failed to connect to MongoDB: {e}")
            if self.client is not None:
                # Best-effort cleanup of the half-open client; a failure here
                # must not mask the original connection error.
                with contextlib.suppress(Exception):
                    self.client.close()
            self.client = None
            self.db = None
            return False

        print(f"✅ Connected to MongoDB: {self.db_name}")
        return True

    def get_collection_data(self, collection_name: str) -> List[Dict]:
        """
        Retrieve all documents from a collection.

        Args:
            collection_name: Name of the MongoDB collection

        Returns:
            List of documents with ``_id`` converted to ``str``; an empty list
            when anything goes wrong (including a missing connection).
        """
        try:
            collection = self.db[collection_name]
            data = list(collection.find({}))

            # Stringify ``_id`` on every fetched document.
            for item in data:
                if '_id' in item:
                    item['_id'] = str(item['_id'])
            print(f"✅ Retrieved {len(data)} documents from {collection_name}")
            return data
        except Exception as e:
            print(f"❌ Error retrieving data from {collection_name}: {e}")
            return []

    def close(self):
        """Close the MongoDB connection, if any, and forget the client/database."""
        if self.client is not None:
            self.client.close()
            print("✅ MongoDB connection closed")
        # Reset state so a closed connection is never mistaken for a live one.
        self.client = None
        self.db = None
|
|
|
|
|
|
|
|
"""=================CLASS EMBEDDING========================""" |
|
|
class DataEmbedding:
    """Shared text and image embedding helpers used by the concrete embedders."""

    def __init__(self):
        pass

    def embed_text_batch(
        self,
        contents: List[str],
        batch_size: int = 32,
        hybrid_mode: bool = False,
    ) -> List[Optional[Tuple[Any, Any, Any]]]:
        """Create text embeddings in batches.

        Every non-empty entry of ``contents`` is embedded with the
        ``HuggingFaceEmbeddings`` model named by ``TEXT_EMBEDDING_MODEL`` and,
        when ``hybrid_mode`` is set, also with BGE-M3 (dense + sparse).

        Args:
            contents: Texts to embed; falsy entries are skipped.
            batch_size: Number of texts sent to the models per call.
            hybrid_mode: Whether to additionally compute BGE-M3 embeddings.

        Returns:
            A list with exactly ``len(contents)`` entries. Each entry is
            ``None`` (skipped text, or embedding failed) or a tuple
            ``(text_tensor, bgem3_dense, bgem3_sparse)``; the two BGE-M3 slots
            are empty lists when ``hybrid_mode`` is off.
        """
        result: List[Optional[Tuple[Any, Any, Any]]] = [None] * len(contents)

        # Only non-empty texts are embedded; remember where they came from.
        valid_indices = [i for i, content in enumerate(contents) if content]
        if not valid_indices:
            return result
        valid_contents = [contents[i] for i in valid_indices]

        try:
            text_embedding_model = HuggingFaceEmbeddings(
                model_name=TEXT_EMBEDDING_MODEL,
                model_kwargs={'device': device},
                encode_kwargs={'normalize_embeddings': True}
            )
            hybrid_embedding_model = None
            if hybrid_mode:
                hybrid_embedding_model = BGEM3FlagModel(
                    "BAAI/bge-m3",
                    use_fp16=True,
                    devices=str(device)
                )

            normal_embeddings, bgem3_dense_embeddings, bgem3_sparse_embeddings = [], [], []
            for start in range(0, len(valid_contents), batch_size):
                batch_contents = valid_contents[start:start + batch_size]

                if hybrid_embedding_model is not None:
                    bgem3_embeddings = hybrid_embedding_model.encode(
                        sentences=batch_contents,
                        return_dense=True,
                        return_sparse=True
                    )
                    bgem3_dense_embeddings.extend(
                        torch.tensor(emb, dtype=torch.float32)
                        for emb in bgem3_embeddings['dense_vecs']
                    )
                    bgem3_sparse_embeddings.extend(bgem3_embeddings['lexical_weights'])

                normal_embeddings.extend(
                    torch.tensor(emb, dtype=torch.float32)
                    for emb in text_embedding_model.embed_documents(batch_contents)
                )

            # Scatter the embeddings back to the positions of their source texts.
            for pos, valid_idx in enumerate(valid_indices):
                if hybrid_mode:
                    result[valid_idx] = (
                        normal_embeddings[pos],
                        bgem3_dense_embeddings[pos],
                        bgem3_sparse_embeddings[pos],
                    )
                else:
                    result[valid_idx] = (normal_embeddings[pos], [], [])

            return result

        except Exception as e:
            print(f"❌ Error in batch text embedding: {str(e)[:100]}...")
            # Keep the "one entry per input" contract on failure as well, so
            # callers can always index the result by document position.
            return [None] * len(contents)

    def embed_images_batch(self, image_urls: List[str], batch_size: int = 32) -> List[Optional[torch.Tensor]]:
        """Create image embeddings in batches.

        Args:
            image_urls: Image URLs to download and embed; falsy entries are
                skipped.
            batch_size: Number of images pushed through the model per call.

        Returns:
            A list with ``len(image_urls)`` entries: the L2-normalised
            embedding tensor for each image that could be downloaded and
            embedded, ``None`` otherwise.
        """
        all_embeddings: List[Optional[torch.Tensor]] = [None] * len(image_urls)

        # Download every image first, remembering its position in the input.
        images_to_process: List[Tuple[Any, int]] = []
        for i, url in enumerate(image_urls):
            if not url:
                continue
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content)).convert('RGB')
            except requests.exceptions.RequestException as e:
                print(f"❌ HTTP error for url {url}: {e}")
            except Exception as e:
                print(f"❌ Error loading image {url}: {e}")
            else:
                images_to_process.append((image, i))

        if not images_to_process:
            return all_embeddings

        image_processor = AutoImageProcessor.from_pretrained(IMAGE_EMBEDDING_MODEL)
        image_embedding_model = EfficientNetModel.from_pretrained(IMAGE_EMBEDDING_MODEL).to(device)

        for start in range(0, len(images_to_process), batch_size):
            batch_data = images_to_process[start:start + batch_size]
            batch_images = [image for image, _ in batch_data]
            batch_indices = [index for _, index in batch_data]

            try:
                inputs = image_processor(images=batch_images, return_tensors="pt").to(device)

                with torch.no_grad():
                    outputs = image_embedding_model(**inputs)

                # Collapse everything after the batch axis instead of calling
                # ``squeeze()``: a bare squeeze also drops the batch axis when
                # the batch holds a single image, which made the ``dim=1``
                # normalisation below fail and lost that image's embedding.
                embeddings = outputs.pooler_output.flatten(start_dim=1)
                normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                for original_index, embedding in zip(batch_indices, normalized_embeddings):
                    all_embeddings[original_index] = embedding

            except Exception as e:
                # A failed batch leaves its slots as ``None``; later batches
                # are still processed.
                print(f"❌ Error embedding image batch: {e}")

        return all_embeddings
|
|
|
|
|
|
|
|
class ProductEmbedding(DataEmbedding):
    """Builds Qdrant-ready embedding records for product documents from MongoDB."""

    # Size of the BGE-M3 dense vector; matches the "product_bgem3_dense" entry
    # of ``product_vectors_config``.
    _BGEM3_DENSE_SIZE = 1024

    @staticmethod
    def _dense_vector(embedding, size: int) -> List[float]:
        """Return ``embedding`` as a plain list, or ``size`` zeros when it is missing.

        "Missing" covers both ``None`` and the empty-list placeholder that
        ``embed_text_batch`` puts in the BGE-M3 slots when hybrid mode is off
        (calling ``.tolist()`` on that placeholder used to raise).
        """
        if embedding is None or len(embedding) == 0:
            return [0.0] * size
        return embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)

    def run_embedding(self, product_type: str, mongodb_conn: MongoDBConnection,
                      batch_size: int = 32, hybrid_mode: bool = False):
        """
        Generate embeddings for a specific product type from MongoDB

        Args:
            product_type: Type of product
            mongodb_conn: MongoDB connection object
            batch_size: Batch size for text embedding
            hybrid_mode: Whether to use hybrid text embedding (BGEM3)

        Returns:
            One dict per document with the keys ``point_id`` (fresh UUID
            string), ``vectors`` (named dense vectors, zero-filled when an
            embedding is missing), ``sparse_vectors`` (the BGE-M3 sparse
            vector, or an empty dict when there is none) and ``payload``.
        """
        processed_docs = self.prepare_docs(
            product_type=product_type,
            mongodb_conn=mongodb_conn
        )

        text_contents = [doc.page_content for doc in processed_docs]
        text_embeddings = self.embed_text_batch(text_contents, batch_size, hybrid_mode)

        image_urls = [doc.metadata.get("image_url") for doc in processed_docs]
        image_embeddings = self.embed_images_batch(image_urls)

        embeddings = []
        for i, doc in enumerate(processed_docs):
            text_entry = text_embeddings[i] if i < len(text_embeddings) else None
            if text_entry is not None:
                normal_text_embedding, bgem3_dense_text_embedding, bgem3_sparse_text_embedding = text_entry
            else:
                normal_text_embedding = bgem3_dense_text_embedding = bgem3_sparse_text_embedding = None

            image_embedding = image_embeddings[i] if i < len(image_embeddings) else None

            vectors = {
                "product": self._dense_vector(normal_text_embedding, TEXT_EMBEDDING_SIZE),
                "product_bgem3_dense": self._dense_vector(bgem3_dense_text_embedding, self._BGEM3_DENSE_SIZE),
                "image": self._dense_vector(image_embedding, IMAGE_EMBEDDING_SIZE),
            }

            # Attach the sparse vector only when one exists. The previous
            # fallback wrote an empty vector under the name "product_sparse",
            # which is not a vector name defined in ``sparse_vectors_config``.
            sparse_vectors = {}
            if bgem3_sparse_text_embedding:
                sparse_vectors["product_bgem3_sparse"] = {
                    # Keys are token ids; cast explicitly so the indices are
                    # integers even when the encoder reports them as strings.
                    "indices": [int(token_id) for token_id in bgem3_sparse_text_embedding],
                    "values": [float(weight) for weight in bgem3_sparse_text_embedding.values()],
                }

            payload = {
                "product": doc.page_content,
                "metadata": dict(doc.metadata),
            }

            embeddings.append({
                "point_id": str(uuid.uuid4()),
                "vectors": vectors,
                "sparse_vectors": sparse_vectors,
                "payload": payload
            })

        print(f"Generated {len(embeddings)} embeddings for {product_type}")
        return embeddings

    def prepare_docs(self, product_type: str, mongodb_conn: MongoDBConnection):
        """
        Prepare documents from MongoDB

        Args:
            product_type: Type of product
            mongodb_conn: MongoDB connection object

        Returns:
            List of ``Document`` objects, one per MongoDB item.

        Raises:
            ValueError: If the connection is not established or the product
                type has no MongoDB collection mapping.
        """
        if not mongodb_conn or mongodb_conn.db is None:
            raise ValueError("MongoDB connection not established")

        collection_name = mongodb_product_collections.get(product_type)
        if not collection_name:
            raise ValueError(f"No MongoDB collection mapping for product type: {product_type}")

        data = mongodb_conn.get_collection_data(collection_name)
        print(f"🗄️ Loaded data from MongoDB collection: {collection_name}")

        docs = []
        # Dict-valued metadata keys that must NOT be expanded into
        # "<key>_<sub_key>" entries ("tags" is already spread by extract_metadata).
        EXCLUDE_FROM_FLATTENING = {"tags"}
        for item in data:
            content = self.create_content(item)
            metadata = self.extract_metadata(item, product_type)

            # Keep the original entries and add flattened copies of nested dicts.
            flat_metadata = {**metadata}
            for key, value in metadata.items():
                if isinstance(value, dict) and key not in EXCLUDE_FROM_FLATTENING:
                    flat_metadata.update({f"{key}_{sub_key}": sub_value for sub_key, sub_value in value.items()})

            docs.append(Document(page_content=content, metadata=flat_metadata))

        print(f"Prepared {len(docs)} documents")
        return docs

    def create_content(self, item: Dict) -> str:
        """Build the markdown-style document text for one product item."""
        product_name = item.get("Tên sản phẩm", "")
        model = item.get("Mã Sản Phẩm", "")
        summary_specs = item.get("Tóm tắt TSKT", "")
        summary_advantages = item.get("Tóm tắt ưu điểm, tính năng", "")
        specs = item.get("Thông số kỹ thuật", "")
        advantages = item.get("Nội dung Ưu điểm SP\n(- File word/Excel\n- Đặt tên file theo mã SAP)", "")
        instruction = item.get("HDSD", "")
        content = (
            f"# Tên sản phẩm: {product_name}\n\n"
            f"## Mã sản phẩm: {model}\n\n"
            f"## Tóm tắt TSKT\n{summary_specs}\n\n"
            f"### Thông số kỹ thuật chi tiết\n{specs}\n\n"
            f"## Tóm tắt ưu điểm & tính năng\n{summary_advantages}\n\n"
            f"### Ưu điểm & tính năng chi tiết\n{advantages}\n"
            f"## Hướng dẫn sử dụng: \n{instruction}\n"
        )

        return content

    def extract_metadata(self, item: Dict, product_type: str) -> Dict:
        """Extract metadata from a product item.

        The result holds the common fields, the raw ``tags`` dict, every tag
        spread as a top-level key, and finally the derived fields from
        ``process_additional_metadata`` (which win on key clashes).
        """
        additional_info = ProductEmbedding.process_additional_metadata(item, product_type)
        # ``Tags`` may be present but null; treat that like "no tags".
        tags = item.get("Tags") or {}
        common_metadata = {
            "prod_id": item.get("Product_ID", None),
            "ten_san_pham": item.get("Tên sản phẩm", ""),
            "model": item.get("Mã Sản Phẩm", ""),
            "danh_muc_l1": item.get("category 1", ""),
            "danh_muc_l2": item.get("category 2", ""),
            "danh_muc_l3": item.get("category 3", ""),
            "url": str(item.get("Link sản phẩm", "")).strip(),
            "image_url": item.get("Link ảnh sản phẩm"),
            "buy_url": item.get("Link mua hàng online", ""),
            "gia": item.get("Giá", ""),
            "tags": tags,
            **tags,
            **additional_info
        }
        return common_metadata

    @staticmethod
    def process_additional_metadata(item: Dict[str, Any], product_type) -> Dict[str, Any]:
        """Process an item and extract additional information.

        Derived fields come from the ``extract_*`` helpers; a field is only
        added when its helper returns something other than ``None``.
        """
        # ``Tags`` may be present but null; treat that like "no tags".
        tags = item.get("Tags") or {}
        spec_text = item.get("Tóm tắt TSKT", "")
        model = item.get("Mã Sản Phẩm", "")
        prod_name = item.get("Tên sản phẩm", "")
        additional_info = {}

        # Derive the power rating only when the tags do not already carry one.
        if tags.get("cong_suat", "") == "":
            power = extract_power(spec_text)
            if power is not None:
                additional_info["cong_suat"] = power

        if product_type == "phich_nuoc":
            pass

        elif product_type == "chieu_sang":
            ceiling_hole_diameter = extract_ceiling_hole_diameter2(spec_text)
            if ceiling_hole_diameter is not None:
                additional_info["duong_kinh_lo_khoet_tran"] = ceiling_hole_diameter

            tinh_nang = extract_tinh_nang(model, prod_name)
            if tinh_nang is not None:
                additional_info["tinh_nang"] = tinh_nang

        elif product_type == "chuyen_dung":
            he_thong_hoa_luoi_pha = extract_he_thong_hoa_luoi_pha(prod_name)
            if he_thong_hoa_luoi_pha is not None:
                additional_info["he_thong_hoa_luoi_pha"] = he_thong_hoa_luoi_pha

        elif product_type == "thiet_bi_dien":
            dong_danh_dinh = extract_dong_danh_dinh(spec_text)
            if dong_danh_dinh is not None:
                additional_info["dong_danh_dinh"] = dong_danh_dinh

        elif product_type == "nha_thong_minh":
            cable_length = extract_cable_length(spec_text)
            if cable_length is not None:
                additional_info["chieu_dai_day"] = cable_length

            plugs_max_current = extract_plugs_max_current(spec_text)
            if plugs_max_current is not None:
                additional_info["dong_dien_o_cam_toi_da"] = plugs_max_current

            voltage = extract_voltage(model)
            if voltage is not None:
                additional_info["dien_ap"] = voltage

        return additional_info
|
|
|
|
|
|
|
|
class SolutionEmbedding(DataEmbedding):
    """Builds Qdrant-ready embedding records for solution documents from MongoDB."""

    def run_embedding(self, solution_type: str, mongodb_conn: MongoDBConnection, batch_size: int = 32):
        """Generate embeddings for a specific solution type from MongoDB.

        Args:
            solution_type: Type of solution
            mongodb_conn: MongoDB connection object
            batch_size: Batch size for text embedding

        Returns:
            One dict per stored document with the keys ``point_id`` (fresh
            UUID string), ``vectors`` (the text embedding as a list, or zeros
            when it is missing) and ``payload``.
        """
        embeddings = []

        # ``processed_docs`` is what gets stored; ``docs_to_embed`` holds the
        # text that is embedded for the document at the same position.
        processed_docs, docs_to_embed = self.prepare_docs(solution_type, mongodb_conn)

        embedding_contents = [doc.page_content for doc in docs_to_embed]
        text_embeddings = self.embed_text_batch(embedding_contents, batch_size)

        for i, doc in enumerate(processed_docs):
            embedding_tuple = text_embeddings[i] if i < len(text_embeddings) else None
            text_embedding = embedding_tuple[0] if embedding_tuple is not None else None

            payload = {
                "content": doc.page_content,
                "metadata": dict(doc.metadata),
            }

            embeddings.append({
                "point_id": str(uuid.uuid4()),
                # The zero fallback uses the configured text embedding size
                # (same constant the product embedder uses for this model)
                # instead of a hard-coded 768.
                "vectors": text_embedding.tolist() if text_embedding is not None else [0.0] * TEXT_EMBEDDING_SIZE,
                "payload": payload
            })

        print(f"Generated {len(embeddings)} embeddings for {solution_type}")
        return embeddings

    def prepare_docs(self, solution_type: str, mongodb_conn: MongoDBConnection):
        """
        Prepare documents from MongoDB

        Args:
            solution_type: Type of solution
            mongodb_conn: MongoDB connection object

        Returns:
            ``(docs, docs_to_embed)``: two position-aligned lists of
            ``Document``. ``docs`` carries the full text to store; for ``faq``
            entries ``docs_to_embed`` carries only the question text, for
            everything else it carries the same text as ``docs``.

        Raises:
            ValueError: If the connection is not established or the solution
                type has no MongoDB collection mapping.
        """
        if not mongodb_conn or mongodb_conn.db is None:
            raise ValueError("MongoDB connection not established")

        collection_name = mongodb_solution_collections.get(solution_type)
        if not collection_name:
            raise ValueError(f"No MongoDB collection mapping for solution type: {solution_type}")

        data = mongodb_conn.get_collection_data(collection_name)
        print(f"🗄️ Loaded solution data from MongoDB collection: {collection_name}")

        docs = []
        docs_to_embed = []

        for item in data:
            for key, val in item.items():
                if key in ["_id", "san_pham"]:
                    continue

                if isinstance(val, list):
                    # NOTE(review): assumes every list entry is a dict --
                    # confirm against the MongoDB schema.
                    for d in val:
                        page_content = ". ".join([f"{k}: {v}" for k, v in d.items()])
                        docs.append(Document(page_content=page_content, metadata={"category": key}))

                        # FAQ entries are embedded by their question only.
                        if key != "faq":
                            embed_content = page_content
                        else:
                            embed_content = f"Câu hỏi: {d.get('Câu hỏi', '')}"
                        docs_to_embed.append(Document(page_content=embed_content, metadata={"category": key}))

                elif isinstance(val, dict):
                    for k, v in val.items():
                        docs_to_embed.append(Document(page_content=f"{k}: {v}", metadata={"category": key}))
                        docs.append(Document(page_content=f"{k}: {v}", metadata={"category": key}))

        print(f"Prepared {len(docs)} documents")
        return docs, docs_to_embed
|
|
|
|
|
|
|
|
"""=================CLASS INDEXING========================""" |
|
|
class ProductIndexing:
    """Indexes product embeddings built from MongoDB into Qdrant collections."""

    # Payload fields indexed for every product collection, as
    # (field name under "metadata.", index schema kind), in creation order.
    _COMMON_PAYLOAD_INDEXES = (
        ("danh_muc_l2", "keyword"),
        ("danh_muc_l3", "keyword"),
        ("gia", "integer"),
        ("cong_suat", "float"),
    )

    # Extra payload indexes per product type, same tuple layout as above.
    _PAYLOAD_INDEXES_BY_TYPE = {
        "phich_nuoc": (
            ("dung_tich", "float"),
            ("chat_lieu", "keyword"),
            ("tinh_nang", "keyword"),
        ),
        "chieu_sang": (
            ("kich_thuoc", "float"),
            ("duong_kinh_lo_khoet_tran", "keyword"),
            ("tinh_nang", "keyword"),
        ),
        "chuyen_dung": (
            ("nhiet_do_mau", "keyword"),
            ("dien_ap", "keyword"),
            ("cong_nghe_led", "keyword"),
            ("loai_den", "keyword"),
            # NOTE(review): ProductEmbedding stores this derived field as
            # "he_thong_hoa_luoi_pha" -- confirm the index name is intended.
            ("he_thong_hoa_luoi", "keyword"),
        ),
        "thiet_bi_dien": (
            ("dong_danh_dinh", "float"),
            ("anh_sang", "keyword"),
            ("so_hat", "keyword"),
            ("so_cuc", "keyword"),
            ("modules", "keyword"),
            ("doi_tuong", "keyword"),
            ("cong_nghe", "keyword"),
            ("loai_den", "keyword"),
            ("san_pham", "keyword"),
        ),
        "nha_thong_minh": (
            ("chieu_dai_day", "float"),
            ("lo_khoet_tran", "integer"),
            ("nut_bam", "keyword"),
            ("dong_dien_o_cam_toi_da", "integer"),
            ("dien_ap", "integer"),
            ("hinh_dang", "keyword"),
            ("tinh_nang", "text"),
            ("goc_chieu", "text"),
            ("combo", "text"),
            ("anh_sang", "text"),
        ),
    }

    def __init__(self, vector_db_client=client):
        """Store the Qdrant client to index into; MongoDB is connected lazily."""
        super().__init__()
        self.client = vector_db_client
        self.mongodb_conn = None

    def setup_mongodb(self, connection_string: Optional[str] = None):
        """Setup MongoDB connection; returns the result of ``connect()``."""
        self.mongodb_conn = MongoDBConnection(connection_string)
        return self.mongodb_conn.connect()

    def _close_mongodb(self):
        """Close and forget the MongoDB connection, if any (safe to call twice)."""
        conn, self.mongodb_conn = self.mongodb_conn, None
        if conn:
            try:
                conn.close()
            except Exception as e:
                print(f"⚠️ Error while closing MongoDB connection: {e}")

    def index(
        self,
        embeddings: List[Dict],
        collection_name: str,
        batch_size: int = 100
    ):
        """Index embeddings to a Qdrant collection in batches.

        Args:
            embeddings: Records with ``point_id``, ``vectors`` (dict of named
                dense vectors), ``sparse_vectors`` (dict of named sparse
                vectors, may be empty) and ``payload``.
            collection_name: Target Qdrant collection.
            batch_size: Number of points per upsert call.

        A batch that raises is counted as failed and the remaining batches are
        still attempted. A summary is printed at the end.
        """
        total_docs = len(embeddings)
        success_count = 0
        error_count = 0

        print(f"Adding {total_docs} multimodal documents to '{collection_name}'...")

        for i in range(0, total_docs, batch_size):
            batch = embeddings[i:i+batch_size]
            batch_number = i//batch_size + 1

            try:
                points = []
                for embedding_data in batch:
                    # Dense and sparse named vectors travel in one mapping.
                    combined_vectors = embedding_data["vectors"].copy()
                    combined_vectors.update(embedding_data.get("sparse_vectors") or {})

                    points.append(qdrant_client.http.models.PointStruct(
                        id=embedding_data["point_id"],
                        vector=combined_vectors,
                        payload=embedding_data["payload"]
                    ))

                self.client.upsert(collection_name=collection_name, points=points)
                success_count += len(batch)

                # A vector that is all zeros is the placeholder for "no embedding".
                text_count = sum(1 for p in points if any(v != 0 for v in p.vector["product"]))
                image_count = sum(1 for p in points if any(v != 0 for v in p.vector["image"]))

                print(f"✅ Batch {batch_number}: {len(batch)} docs | {text_count} product | {image_count} images")

            except Exception as e:
                error_count += len(batch)
                print(f"❌ Batch {batch_number} failed: {e}")

        print(f"\n📊 Final Results:")
        print(f" ✅ Successful: {success_count}")
        print(f" ❌ Failed: {error_count}")
        # With nothing to index there is no rate to report (the unguarded
        # division used to raise ZeroDivisionError here).
        attempted = success_count + error_count
        if attempted:
            print(f" 📈 Success Rate: {success_count/attempted*100:.1f}%")
        else:
            print(" 📈 Success Rate: n/a (no documents to index)")

    def run_indexing(self, reload: bool = True, hybrid_mode: bool = True):
        """
        Index all product data from MongoDB into Qdrant collections.

        Args:
            reload: Whether to recreate collections
            hybrid_mode: Whether to use hybrid text embedding (BGEM3)

        Returns:
            Everything printed to stdout/stderr during the run, as one string.
        """
        with OutputCapture() as output:
            try:
                self._index_all_products(reload, hybrid_mode)
            except Exception as e:
                print(f"❌ Fatal error during indexing: {e}")
                import traceback
                print(traceback.format_exc())
            finally:
                # Never leave the connection open (or a failed connection
                # object cached) after an aborted run.
                self._close_mongodb()

            return output.getvalue()

    def _index_all_products(self, reload: bool, hybrid_mode: bool):
        """Run the recreate / connect / embed / index steps for every product type."""
        if reload:
            try:
                for collection in product_collections:
                    self.client.recreate_collection(
                        collection_name=collection,
                        vectors_config=product_vectors_config,
                        sparse_vectors_config=sparse_vectors_config
                    )
                print("✅ All product collections recreated.")
            except Exception as e:
                print(f"❌ Error while recreating collections: {e}")
                return

        if not self.mongodb_conn and not self.setup_mongodb():
            print("❌ Failed to connect to MongoDB. Aborting indexing.")
            return

        embed_object = ProductEmbedding()

        # One product type failing must not stop the remaining ones.
        for collection, product_type in zip(product_collections, product_types):
            print(f"\n{'='*60}")
            print(f"🔄 Processing {product_type} data from MongoDB...")
            print(f"{'='*60}")

            try:
                embeddings = embed_object.run_embedding(
                    product_type=product_type,
                    mongodb_conn=self.mongodb_conn,
                    hybrid_mode=hybrid_mode
                )

                self.index(embeddings, collection)
                self._create_payload_indexes_for_product_type(product_type, collection)

                print(f"✅ Completed indexing for {product_type}")

            except Exception as e:
                print(f"❌ Error indexing {product_type}: {e}")
                import traceback
                print(traceback.format_exc())

        self._close_mongodb()

        print(f"\n{'='*60}")
        print("🎉 All indexing completed!")
        print(f"{'='*60}")

    def indexing_single_product_type(self, product_type: str, collection_name: str,
                                     hybrid_mode: bool = True) -> str:
        """
        Indexing a single product group into its Qdrant collection from MongoDB

        Args:
            product_type: Type of product
            collection_name: Qdrant collection name
            hybrid_mode: Whether to use hybrid text embedding (BGEM3)

        Returns:
            Everything printed to stdout/stderr during the run, as one string.
        """
        with OutputCapture() as output:
            try:
                print(f"{'='*60}")
                print(f"🚀 Starting indexing for {product_type}")
                print(f"{'='*60}\n")

                self.client.recreate_collection(
                    collection_name=collection_name,
                    vectors_config=product_vectors_config,
                    sparse_vectors_config=sparse_vectors_config
                )
                print(f"✅ Collection {collection_name} created\n")

                if not self.mongodb_conn and not self.setup_mongodb():
                    print("❌ Failed to connect to MongoDB")
                else:
                    embed_object = ProductEmbedding()

                    print(f"🔄 Processing {product_type} data from MongoDB...")
                    embeddings = embed_object.run_embedding(
                        product_type=product_type,
                        mongodb_conn=self.mongodb_conn,
                        hybrid_mode=hybrid_mode
                    )

                    print(f"\n📊 Indexing to Qdrant...")
                    self.index(embeddings, collection_name)

                    print(f"\n🔍 Creating payload indexes...")
                    self._create_payload_indexes_for_product_type(product_type, collection_name)

                    self._close_mongodb()

                    print(f"\n{'='*60}")
                    print(f"✅ Successfully completed indexing for {product_type}")
                    print(f"{'='*60}")

            except Exception as e:
                print(f"❌ Error while indexing product type {product_type}: {e}")
                import traceback
                print(traceback.format_exc())
            finally:
                # Never leave the connection open (or a failed connection
                # object cached) after an aborted run.
                self._close_mongodb()

            return output.getvalue()

    def _create_payload_indexes_for_product_type(self, product_type: str, collection_name: str):
        """Create payload indexes based on product type field schemas.

        The common indexes are created first, then the ones registered for
        ``product_type`` (none for an unknown type). The first failure stops
        the remaining index creations and is reported on stdout.
        """
        print(f"🔍 Creating payload indexes for {product_type}...")

        fields = self._COMMON_PAYLOAD_INDEXES + self._PAYLOAD_INDEXES_BY_TYPE.get(product_type, ())

        try:
            models = qdrant_client.http.models
            # Schema kind -> params class; each is built as Params(type=<kind>).
            schema_factories = {
                "keyword": models.KeywordIndexParams,
                "integer": models.IntegerIndexParams,
                "float": models.FloatIndexParams,
                "text": models.TextIndexParams,
            }

            for field, kind in fields:
                self.client.create_payload_index(
                    collection_name=collection_name,
                    field_name=f"metadata.{field}",
                    field_schema=schema_factories[kind](type=kind)
                )

            print(f"✅ All payload indexes created for {product_type}")

        except Exception as e:
            print(f"❌ Error creating payload indexes for {product_type}: {e}")
|
|
|
|
|
class SolutionIndexing:
    """Index solution embeddings from MongoDB into Qdrant collections.

    The Qdrant client is injected at construction time. The MongoDB
    connection is opened lazily by the indexing entry points and is always
    released (and reset to ``None``) when an entry point finishes, whether it
    succeeded or failed.
    """

    # Size of the single dense vector every solution collection is created with.
    VECTOR_SIZE = 768

    def __init__(self, vector_db_client=client):
        """Store the Qdrant client; no MongoDB connection is opened yet.

        Args:
            vector_db_client: Qdrant client used for all collection
                operations. Defaults to the module-level ``client``.
        """
        super().__init__()
        self.client = vector_db_client
        # Populated by setup_mongodb(); None whenever no usable connection exists.
        self.mongodb_conn = None

    def setup_mongodb(self, connection_string: Optional[str] = None):
        """Setup MongoDB connection.

        Args:
            connection_string: Passed through to ``MongoDBConnection``.

        Returns:
            Whatever ``MongoDBConnection.connect()`` returns; callers treat a
            falsy value as "connection failed".
        """
        connection = MongoDBConnection(connection_string)
        connected = connection.connect()
        # Only remember a connection that actually came up. Keeping a failed
        # one would make later runs skip reconnecting and use a dead handle.
        self.mongodb_conn = connection if connected else None
        return connected

    def _ensure_mongodb(self) -> bool:
        """Return True when a MongoDB connection is available, opening one if needed."""
        if self.mongodb_conn:
            return True
        return bool(self.setup_mongodb())

    def _close_mongodb(self) -> None:
        """Close the MongoDB connection (if any) and forget it."""
        connection, self.mongodb_conn = self.mongodb_conn, None
        if not connection:
            return
        try:
            connection.close()
        except Exception as e:
            # Cleanup is best effort: a failing close must not mask the run's result.
            print(f"⚠️ Error while closing MongoDB connection: {e}")

    def _recreate_collection(self, collection_name: str) -> None:
        """Recreate ``collection_name`` with one cosine-distance vector of VECTOR_SIZE."""
        self.client.recreate_collection(
            collection_name=collection_name,
            vectors_config=qdrant_client.http.models.VectorParams(
                size=self.VECTOR_SIZE,
                distance=qdrant_client.http.models.Distance.COSINE,
            ),
        )

    @staticmethod
    def _print_traceback() -> None:
        """Print the traceback of the exception currently being handled."""
        import traceback

        print(traceback.format_exc())

    def index(
        self,
        embeddings: List[Dict],
        collection_name: str,
        batch_size: int = 10
    ) -> None:
        """Index embeddings to a Qdrant collection in batches.

        Each entry of ``embeddings`` must provide the keys ``"point_id"``,
        ``"vectors"`` and ``"payload"``. A batch that fails is reported and
        counted as failed; the remaining batches are still processed.

        Args:
            embeddings: Records to upload.
            collection_name: Target Qdrant collection.
            batch_size: Number of points sent per upsert call; must be > 0.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        total_docs = len(embeddings)
        success_count = 0
        error_count = 0

        print(f"Adding {total_docs} solution documents to '{collection_name}'...")

        for start in range(0, total_docs, batch_size):
            batch = embeddings[start:start + batch_size]
            batch_number = start // batch_size + 1

            try:
                points = [
                    qdrant_client.http.models.PointStruct(
                        id=item["point_id"],
                        vector=item["vectors"],
                        payload=item["payload"],
                    )
                    for item in batch
                ]
                self.client.upsert(collection_name=collection_name, points=points)
                # Points whose vector has at least one non-zero component.
                # NOTE(review): assumes each vector is a flat sequence of
                # numbers - confirm against SolutionEmbedding's output.
                text_count = sum(1 for p in points if any(v != 0 for v in p.vector))
            except Exception as e:
                error_count += len(batch)
                print(f"❌ Batch {batch_number} failed: {e}")
            else:
                # Counted only here so a batch is never both successful and failed.
                success_count += len(batch)
                print(f"✅ Batch {batch_number}: {len(batch)} docs | {text_count} contents")

        processed = success_count + error_count
        print("\n📊 Final Results:")
        print(f" ✅ Successful: {success_count}")
        print(f" ❌ Failed: {error_count}")
        if processed:
            print(f" 📈 Success Rate: {success_count/processed*100:.1f}%")
        else:
            # Nothing was processed (empty input): avoid dividing by zero.
            print(" 📈 Success Rate: N/A (no documents to index)")

    def _index_all_solutions(self, reload: bool) -> None:
        """Run the steps of run_indexing(); errors propagate to the caller."""
        # Verify the data source before touching Qdrant, so an unreachable
        # MongoDB never leaves freshly emptied collections behind.
        if not self._ensure_mongodb():
            print("❌ Failed to connect to MongoDB. Aborting indexing.")
            return

        if reload:
            try:
                for collection in solution_collections:
                    self._recreate_collection(collection)
                print("✅ All solution collections recreated.")
            except Exception as e:
                print(f"❌ Error while recreating collections: {e}")
                return

        embed_object = SolutionEmbedding()

        for collection, solution_type in zip(solution_collections, solution_types):
            print(f"\n{'='*60}")
            print(f"🔄 Processing {solution_type} data from MongoDB...")
            print(f"{'='*60}")

            try:
                embeddings = embed_object.run_embedding(solution_type, self.mongodb_conn)
                self.index(embeddings, collection)
                print(f"✅ Completed indexing for {solution_type}")
            except Exception as e:
                # One failing solution type must not stop the others.
                print(f"❌ Error indexing {solution_type}: {e}")
                self._print_traceback()

        print(f"\n{'='*60}")
        print("🎉 All solution indexing completed!")
        print(f"{'='*60}")

    def run_indexing(self, reload: bool = True):
        """Index all solution data from MongoDB into Qdrant collections.

        Args:
            reload: When true, every collection in ``solution_collections`` is
                recreated before indexing.

        Returns:
            The text collected by ``OutputCapture`` while the run executed.
        """
        with OutputCapture() as output:
            try:
                self._index_all_solutions(reload)
            except Exception as e:
                print(f"❌ Fatal error during indexing: {e}")
                self._print_traceback()
            finally:
                # Release the connection on every path, including failures.
                self._close_mongodb()

            return output.getvalue()

    def indexing_single_solution(self, solution: str, collection_name: str) -> str:
        """Indexing a single solution into its Qdrant collection from MongoDB.

        Args:
            solution: Solution type key handed to ``SolutionEmbedding.run_embedding``.
            collection_name: Qdrant collection that is recreated and filled.

        Returns:
            The text collected by ``OutputCapture`` while the run executed.
        """
        with OutputCapture() as output:
            try:
                print(f"{'='*60}")
                print(f"🚀 Starting indexing for {solution}")
                print(f"{'='*60}\n")

                # Connect first: recreating the collection is destructive and
                # pointless if the source data cannot be read.
                if not self._ensure_mongodb():
                    print("❌ Failed to connect to MongoDB")
                else:
                    self._recreate_collection(collection_name)
                    print(f"✅ Collection {collection_name} created\n")

                    embed_object = SolutionEmbedding()

                    print(f"🔄 Processing {solution} data from MongoDB...")
                    embeddings = embed_object.run_embedding(solution, self.mongodb_conn)

                    print("\n📊 Indexing to Qdrant...")
                    self.index(embeddings, collection_name)

                    print(f"\n{'='*60}")
                    print(f"✅ Successfully completed indexing for {solution}")
                    print(f"{'='*60}")
            except Exception as e:
                print(f"❌ Error while indexing solution {solution}: {e}")
                self._print_traceback()
            finally:
                # Release the connection on every path, including failures.
                self._close_mongodb()

            return output.getvalue()
|
|
|
|
|
|
|
|
"""=================GRADIO UI========================""" |
|
|
def create_indexing_interface():
    """Create Gradio interface for indexing from MongoDB.

    Builds a Blocks app with a shared read-only log textbox, one button per
    solution and per product type, and one "index everything" button for each
    of the two groups. Every button writes its handler's return value to the
    log textbox.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    product_indexing = ProductIndexing()
    solution_indexing = SolutionIndexing()

    # (button label, solution key, Qdrant collection), grouped by UI row.
    solution_rows = (
        (
            ("GP Ngư nghiệp", "ngu_nghiep", QDRANT_COLLECTION_NAME_GPNGUNGHIEP),
            ("GP Học đường", "hoc_duong", QDRANT_COLLECTION_NAME_GPHOCDUONG),
            ("GP Nhà thông minh", "nha_thong_minh", QDRANT_COLLECTION_NAME_GPNHATHONGMINH),
            ("GP Nông nghiệp CNC", "nong_nghiep_cnc", QDRANT_COLLECTION_NAME_GPNNCNC),
        ),
        (
            ("GP Cảnh quan", "canh_quan", QDRANT_COLLECTION_NAME_GPCANHQUAN),
            ("GP HTĐ NLMT", "nlmt", QDRANT_COLLECTION_NAME_GPNLMT),
            ("GP Đường phố", "duong_pho", QDRANT_COLLECTION_NAME_GPDUONGPHO),
            ("GP Văn phòng công sở", "van_phong_cong_so", QDRANT_COLLECTION_NAME_GPVPCS),
        ),
        (
            ("GP Nhà máy CN", "nha_may_cong_nghiep", QDRANT_COLLECTION_NAME_GPNMCN),
            ("GP Nhà ở xã hội", "nha_o_xa_hoi", QDRANT_COLLECTION_NAME_GPNOXH),
        ),
    )

    # (button label, product key, Qdrant collection), all shown in one row.
    product_specs = (
        ("SP Phích nước", "phich_nuoc", QDRANT_COLLECTION_NAME_SPPHICHNUOC),
        ("SP Chiếu sáng", "chieu_sang", QDRANT_COLLECTION_NAME_SPCHIEUSANG),
        ("SP Chuyên dụng", "chuyen_dung", QDRANT_COLLECTION_NAME_SPCHUYENDUNG),
        ("SP Nhà thông minh", "nha_thong_minh", QDRANT_COLLECTION_NAME_SPNHATHONGMINH),
        ("SP Thiết bị điện", "thiet_bi_dien", QDRANT_COLLECTION_NAME_SPTHIETBIDIEN),
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🗄️ Qdrant Data Indexing System (MongoDB)")
        gr.Markdown("Recreate Qdrant Collections and Index Data from MongoDB Atlas")

        output_box = gr.Textbox(
            label="📋 Logs",
            lines=20,
            max_lines=30,
            interactive=False,
            show_copy_button=True,
        )

        def add_solution_button(label, solution_key, collection):
            # Create the button in the current layout context and wire its click.
            gr.Button(label).click(
                solution_indexing.indexing_single_solution,
                inputs=[gr.State(solution_key), gr.State(collection)],
                outputs=output_box,
            )

        def wire_product_button(button, product_key, collection):
            # The trailing True is the handler's third positional argument.
            button.click(
                product_indexing.indexing_single_product_type,
                inputs=[gr.State(product_key), gr.State(collection), gr.State(True)],
                outputs=output_box,
            )

        gr.Markdown("---")
        gr.Markdown("## 🏢 Giải pháp (Solutions)")

        *leading_rows, closing_row = solution_rows
        for row_specs in leading_rows:
            with gr.Row():
                for spec in row_specs:
                    add_solution_button(*spec)

        # The last row also hosts the "index every solution" button.
        with gr.Row():
            for spec in closing_row:
                add_solution_button(*spec)

            gr.Button("✨ Tất cả GP", variant="primary").click(
                solution_indexing.run_indexing,
                inputs=gr.State(True),
                outputs=output_box,
            )

        gr.Markdown("---")
        gr.Markdown("## 📦 Sản phẩm (Products)")

        # Lay out all product buttons first, then attach the handlers below.
        with gr.Row():
            product_buttons = [gr.Button(label) for label, _, _ in product_specs]

        with gr.Row():
            btn_all_products = gr.Button("✨ Tất cả SP", variant="primary", scale=2)

        for button, (_, product_key, collection) in zip(product_buttons, product_specs):
            wire_product_button(button, product_key, collection)

        btn_all_products.click(
            product_indexing.run_indexing,
            inputs=[gr.State(True), gr.State(True)],
            outputs=output_box,
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Blocks app and serve it with Gradio's
    # default launch settings. The app is bound to the name `demo`, exactly
    # as in the original script.
    demo = create_indexing_interface()
    demo.launch()