Browse Source

1.引入自动模板优化模型
2.尝试大模型自动补全

Air 3 weeks ago
parent
commit
d5454073ac

+ 29 - 4
data/analyze_data/analyze_xls.py

@@ -1,14 +1,17 @@
 import pandas as pd
 
+from src.kg_construction.llm_construct_kg import sort_format
+
+
 def readXls(path):
     # 读取xls文件
     df = pd.read_excel(path)
 
     # 分隔
-    df.iloc[:, 0] = df.iloc[:, 0].astype(str).str.split('+')
-    df.iloc[:, 2] = df.iloc[:, 2].astype(str).str.split('/')
-    df.iloc[:, 4] = df.iloc[:, 4].astype(str).str.split('/')
-    df.iloc[:, 6] = df.iloc[:, 6].astype(str).str.split('/')
+    #df.iloc[:, 0] = df.iloc[:, 0].astype(str).str.split('+')
+    #df.iloc[:, 2] = df.iloc[:, 2].astype(str).str.split('/')
+    #df.iloc[:, 4] = df.iloc[:, 4].astype(str).str.split('/')
+    #df.iloc[:, 6] = df.iloc[:, 6].astype(str).str.split('/')
 
     # 将DataFrame数据转换为字典列表
     data_list = df.to_dict('records')
@@ -16,5 +19,27 @@ def readXls(path):
     return data_list
 
 
+def analyze_entity(data_list):
+    set_post = set()
+    set_job_category = set()
+    set_company_industry = set()
+    set_company_name = set()
+    set_company_nature = set()
+    set_city = set()
+    for data in data_list:
+        set_post.add(data['岗位名称'])
+        set_job_category.add(data['职位类别'])
+        set_company_industry.add(data['公司行业'])
+        set_company_name.add(data['公司名称'])
+        set_company_nature.add(data['公司性质'])
+        set_city.add((data['城市']))
+    return set_post, set_job_category, set_company_industry, set_company_name, set_company_nature, set_city
+
+def set_become_dict_list(set_job):
+    list_job = []
+    for data in set_job:
+        list_job.append({'entity':[sort_format(data)]})
+    return list_job
+
 
 

+ 54 - 1
src/kg_construction/llm_construct_kg.py

@@ -4,11 +4,63 @@ from data.analyze_data.analyze_pdf import readPdf
 
 
 def get_response_vicuna(prompt):
-    content = ollama.generate(model='llama3:latest', prompt=prompt)
+    content = ollama.generate(model='qwen:14b', prompt=prompt)
     return content['response']
 
+def sort_format(entity):
+    template_result = """
+        ### 指导:
+        给你一个职位类别 {},请将其整理为标准格式,你的回答只需要包含这个结果,不需要包含其他的内容。
 
+        ### 例子1:
+        给出:电气机械,电力设备及计算机软件
+        回答:电气机械/电力设备/计算机软件
 
+        ### 例子2:
+        给出:租赁服务
+        回答:租赁服务
+
+        ### 回答:
+        """
+    respond = get_response_vicuna(template_result.format(entity))
+    return respond
+
+def entity_relation_generation(entity):
+    template_kgc = """
+        ### 指导:
+        给一个头实体 {},请根据事实知识,生成其关系与尾实体,并输出由其组成的三元组,你的回答只需要包含这个三元组:(头实体,关系,尾实体),不需要包含其他的内容。
+
+        ### 例子:
+        给出:榆林神木
+        回答:(榆林神木, 矿藏, 镁)
+
+        ### 回答:
+        """
+    respond = get_response_vicuna(template_kgc.format(entity))
+    return respond
+
+def realtion_generation(head, tail):
+    template_relation = """
+        ### 指导:
+        给一个头实体 {} 与尾实体 {},请根据事实知识,生成头实体与为实体之间的关系,你的回答只需要包含这个关系,不需要包含其他的内容。
+
+        ### 例子:
+        给出:榆林神木,镁
+        回答:矿藏
+
+        ### 回答:
+        """
+    respond = get_response_vicuna(template_relation.format(head, tail))
+    return respond
+
+
+
+
+
+
+
+
+'''
 if __name__ == "__main__":
     path = '../../data/source/机械相关专业数据汇总/专业分类数据(岗位分析套表、技能等级标准等)/460101【机械设计与制造专业】/460101【机械设计与制造专业】典型岗位的岗位标准分析套表/460101【机械设计与制造专业】典型岗位的岗位标准分析套表-济南职业学院交付产出/460101【机械设计与制造专业】典型岗位的岗位标准分析套表.pdf'
     pageNumber = 3
@@ -17,3 +69,4 @@ if __name__ == "__main__":
 
     respond = get_response_vicuna("请对以下文本提取三元组:" + context)
     print(respond)
+'''

+ 70 - 0
src/kg_construction/manager.py

@@ -0,0 +1,70 @@
+from data.analyze_data.analyze_xls import readXls, set_become_dict_list, analyze_entity
+from src.kg_construction.llm_construct_kg import get_response_vicuna, sort_format, entity_relation_generation
+from src.kg_construction.mongodb_cache import MongoDBConn
+from src.template_generate.infer_example import gen
+from tool.judge_respond_structure import judge_respond_triple_structure
+
+if __name__ == "__main__":
+    host_port = 'mongodb://8.142.150.114:27018/'
+    db_name = 'nebula-kg-cache'
+    assemble = 'mechatronics-entity-seed'
+    mongoDBConn = MongoDBConn()
+    collection = mongoDBConn.initConnect(host_port=host_port, db_name=db_name, assemble=assemble)
+
+    template_kgc = """
+            ### 指导:
+            给一个头实体 {},请根据事实知识,生成其关系与尾实体,并输出由其组成的三元组,你的回答只需要包含这个三元组:(头实体,关系,尾实体),不需要包含其他的内容。
+
+            ### 例子:
+            给出:榆林神木
+            回答:(榆林神木, 矿藏, 镁)
+
+            ### 回答:
+            """
+    template_kgc_improve = gen(template_kgc)
+    print(template_kgc_improve)
+
+
+'''
+    job_list = mongoDBConn.find_all(collection=collection)
+    for job in job_list:
+        #print(job['entity'][0])
+        triple = entity_relation_generation(job['entity'][0])
+        print(triple)
+        print(judge_respond_triple_structure(triple))
+'''
+
+
+
+
+'''
+    path = '../../data/source/机械相关专业数据汇总/爬虫数据(山东省数据)/【最终】2024_11_03_临沂职业学院-智联招聘和前程无忧数据.xls'
+    data_list = readXls(path)
+    set_post, set_job_category, set_company_industry, set_company_name, set_company_nature, set_city = analyze_entity(data_list)
+    list_job = set_become_dict_list(set_job_category)
+    print(list_job)
+    mongoDBConn.insert_dict_list(collection=collection, list_job=list_job)
+'''
+'''
+    template_kgc = """
+        ### 指导:
+        头实体 {},请根据事实知识,生成其关系与尾实体,并输出由其组成的三元组,你的回答只需要包含这个三元组,不需要包含其他的内容。
+
+        ### 例子1:
+        给出:榆林神木
+        回答:榆林神木, 矿藏, 镁
+
+        ### 回答:
+        """
+
+    job_list = mongoDBConn.find_all(collection=collection)
+    print((job_list[3]["entity"][0]))
+    respond = sort_format(job_list[3]["entity"][0])
+    print(respond)
+'''
+
+
+
+
+
+

+ 8 - 1
src/kg_construction/mongodb_cache.py

@@ -14,4 +14,11 @@ class MongoDBConn:
         mydb = myclient[db_name]
         collection = mydb[assemble]
 
-        return collection
+        return collection
+
+    def insert_dict_list(self, collection, list_job):
+        collection.insert_many(list_job)
+
+    def find_all(self, collection):
+        job_list = collection.find()
+        return job_list

+ 67 - 0
src/template_generate/infer_example.py

@@ -0,0 +1,67 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+# TODO change model path
+model_path = '/data/codes/ycj/models/BPO'
+
+prompt_template = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]"
+
+device = 'cuda:0'
+model = AutoModelForCausalLM.from_pretrained(model_path).half().eval().to(device)
+# for 8bit
+# model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, load_in_8bit=True)
+tokenizer = AutoTokenizer.from_pretrained(model_path, add_prefix_space=True)
+    
+
+def gen(input_text):
+    prompt = prompt_template.format(input_text)
+    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.6, num_beams=1)
+    resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()
+
+    #print("[Stable Optimization] ", resp)
+    return resp
+
+
+def gen_aggressive(input_text):
+    texts = [input_text] * 5  
+    responses = []
+    for text in texts:
+        seed = torch.seed()
+        torch.manual_seed(seed)
+        prompt = prompt_template.format(text)
+        min_length = len(tokenizer(prompt)['input_ids']) + len(tokenizer(text)['input_ids']) + 5
+        model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
+        bad_words_ids = [tokenizer(bad_word, add_special_tokens=False).input_ids for bad_word in ["[PROTECT]", "\n\n[PROTECT]", "[KEEP", "[INSTRUCTION]"]]
+        # eos and \n
+        eos_token_ids = [tokenizer.eos_token_id, 13]
+        output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.9, bad_words_ids=bad_words_ids, num_beams=1, eos_token_id=eos_token_ids, min_length=min_length)
+        resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].split('[KE')[0].split('[INS')[0].split('[PRO')[0].strip()
+        responses.append(resp)
+
+    for i in responses:
+        print("[Aggressive Optimization] ", i)
+
+
+text = '''This is an emotional triplet extraction, which contains the answers to three questions: which subject is being discussed, what is the emotional polarity of this subject, and why is it this emotional polarity. Here is an example:' \
+       text:Our agreed favorite is the orrechiete with sausage and chicken ( usually the waiters are kind enough to split the dish in half so you get to sample both meats ) .' \
+       label:('orrechiete with sausage and chicken', 'favorite', 'POS'), ('waiters', 'kind', 'POS') '''
+
+pro="""
+Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request, Your answer only needs to contain the words "True" or "False", no additional words or punctuation need be included, your response can only contain one word.
+### Instruction:
+Given a triple from a knowledge graph. Each triple consists of a head entity, a relation, and a tail entity. Please determine the correctness of the triple and response "True" or "False".
+
+### Input:
+{}
+
+### Response:
+
+"""
+# Stable optimization, this will sometimes maintain the original prompt
+# gen(text)
+#gen(pro)
+# Agressive optimization, this will refine the original prompt with a higher possibility
+# but there may be inappropriate changes
+# gen_aggressive(text)
+#gen_aggressive(pro)

+ 8 - 0
test/analyze_xls_test.py

@@ -0,0 +1,8 @@
+from data.analyze_data.analyze_xls import readXls, analyze_entity, set_become_dict_list
+
+if __name__ == "__main__":
+    path = '../data/source/机械相关专业数据汇总/爬虫数据(山东省数据)/【最终】2024_11_03_临沂职业学院-智联招聘和前程无忧数据.xls'
+    data_list = readXls(path)
+    set_post, set_job_category, set_company_industry, set_company_name, set_company_nature, set_city = analyze_entity(data_list)
+    list_job = set_become_dict_list(set_post)
+    print(list_job)

+ 2 - 2
test/mangodb_test.py

@@ -3,9 +3,9 @@ from src.kg_construction.mongodb_cache import MongoDBConn
 if __name__ == "__main__":
     host_port = 'mongodb://8.142.150.114:27018/'
     db_name = 'nebula-kg-cache'
-    assemble = 'mechatronics-seed'
+    assemble = 'mechatronics-entity-seed'
     mongoDBConn = MongoDBConn()
     collection = mongoDBConn.initConnect(host_port=host_port, db_name=db_name, assemble=assemble)
-    triple_sample_templete = {"name":"ycj"}
+    triple_sample_templete = {'name': 'MongoDB', 'type': 'database'}
     collection.insert_one(triple_sample_templete)
     print(collection)

+ 7 - 0
tool/judge_respond_structure.py

@@ -0,0 +1,7 @@
+
+
+def judge_respond_triple_structure(respond):
+    if respond[0] == '(' and respond[len(respond)-1] == ')' and respond.count(',') == 2:
+        return True
+    else:
+        return False